feat: transform engine to execute transforms (explodinggradients#1352)
Code to orchestrate your transforms and apply them to a `KnowledgeGraph`:

```py
from ragas.experimental.testset.transforms import Parallel, TransformerEngine
from ragas.experimental.testset.transforms.splitters import HeadlineSplitter
from ragas.experimental.testset.transforms.extractors import SummaryExtractor, KeyphrasesExtractor, TitleExtractor, HeadlinesExtractor, EmbeddingExtractor

# define the transforms
summary_extractor = SummaryExtractor()
keyphrase_extractor = KeyphrasesExtractor()
title_extractor = TitleExtractor()
headline_extractor = HeadlinesExtractor()
embedding_extractor = EmbeddingExtractor()
headline_splitter = HeadlineSplitter()

# specify the transforms and their order to be applied
transforms = [
    headline_extractor,
    headline_splitter,
    Parallel(
        embedding_extractor,
        summary_extractor
    )
]

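# run the transforms against an existing KnowledgeGraph instance `kg`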
TransformerEngine().apply(transforms, kg)
```
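For context, each entry in the `transforms` list runs to completion before the next starts, while the per-node coroutines inside a single transform (or a `Parallel` group) run concurrently. The sketch below (not part of the original description) passes an explicit `RunConfig` to cap that concurrency; constructing it as `RunConfig(max_workers=8)` is an assumption, and only `max_workers` is read by the engine:

```py
from ragas.run_config import RunConfig

# same call as above, but with the concurrency ceiling made explicit
TransformerEngine().apply(transforms, kg, run_config=RunConfig(max_workers=8))
```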
jjmachan authored Sep 26, 2024
1 parent 3940964 commit d0d33b0
Showing 11 changed files with 240 additions and 8 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -55,3 +55,6 @@ build-backend = "setuptools.build_meta"

[tool.setuptools_scm]
write_to = "src/ragas/_version.py"

[tool.pytest.ini_options]
addopts = "-n 4"
2 changes: 1 addition & 1 deletion src/ragas/executor.py
@@ -41,8 +41,8 @@ async def sema_coro(coro):
@dataclass
class Executor:
    desc: str = "Evaluating"
-    keep_progress_bar: bool = True
    show_progress: bool = True
    keep_progress_bar: bool = True
    jobs: t.List[t.Any] = field(default_factory=list, repr=False)
    raise_exceptions: bool = False
    run_config: t.Optional[RunConfig] = field(default=None, repr=False)
Empty file.
Empty file.
Empty file.
8 changes: 8 additions & 0 deletions src/ragas/experimental/testset/transforms/__init__.py
@@ -0,0 +1,8 @@
from .base import BaseGraphTransformations
from .engine import Parallel, TransformerEngine

__all__ = [
"BaseGraphTransformations",
"Parallel",
"TransformerEngine",
]
99 changes: 92 additions & 7 deletions src/ragas/experimental/testset/transforms/base.py
@@ -1,10 +1,13 @@
import logging
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from ragas.experimental.testset.graph import KnowledgeGraph, Node, Relationship
from ragas.llms import BaseRagasLLM, llm_factory

logger = logging.getLogger(__name__)


class BaseGraphTransformations(ABC):
"""
Expand Down Expand Up @@ -46,6 +49,24 @@ def filter(self, kg: KnowledgeGraph) -> KnowledgeGraph:
"""
return kg

@abstractmethod
def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
"""
Generates a list of coroutines to be executed in sequence by the Executor. This
coroutine will, upon execution, write the transformation into the KnowledgeGraph.
Parameters
----------
kg : KnowledgeGraph
The knowledge graph to be transformed.
Returns
-------
t.List[t.Coroutine]
A list of coroutines to be executed in parallel.
"""
pass


class Extractor(BaseGraphTransformations):
"""
Expand Down Expand Up @@ -108,6 +129,35 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
"""
pass

def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
"""
Generates a list of coroutines to be executed in parallel by the Executor.
Parameters
----------
kg : KnowledgeGraph
The knowledge graph to be transformed.
Returns
-------
t.List[t.Coroutine]
A list of coroutines to be executed in parallel.
"""

async def apply_extract(node: Node):
property_name, property_value = await self.extract(node)
if node.get_property(property_name) is None:
node.add_property(property_name, property_value)
else:
logger.warning(
"Property '%s' already exists in node '%.6s'. Skipping!",
property_name,
node.id,
)

filtered = self.filter(kg)
return [apply_extract(node) for node in filtered.nodes]


@dataclass
class LLMBasedExtractor(Extractor):
Expand Down Expand Up @@ -173,6 +223,29 @@ async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]
"""
pass

def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
"""
Generates a list of coroutines to be executed in parallel by the Executor.
Parameters
----------
kg : KnowledgeGraph
The knowledge graph to be transformed.
Returns
-------
t.List[t.Coroutine]
A list of coroutines to be executed in parallel.
"""

async def apply_split(node: Node):
nodes, relationships = await self.split(node)
kg.nodes.extend(nodes)
kg.relationships.extend(relationships)

filtered = self.filter(kg)
return [apply_split(node) for node in filtered.nodes]


class RelationshipBuilder(BaseGraphTransformations):
"""
Expand All @@ -181,7 +254,7 @@ class RelationshipBuilder(BaseGraphTransformations):
Methods
-------
transform(kg: KnowledgeGraph) -> t.List[Relationship]
Abstract method to transform the KnowledgeGraph by building relationships.
Transforms the KnowledgeGraph by building relationships.
"""

@abstractmethod
Expand All @@ -201,12 +274,24 @@ async def transform(self, kg: KnowledgeGraph) -> t.List[Relationship]:
"""
pass

def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
"""
Generates a list of coroutines to be executed in parallel by the Executor.
class Parallel:
def __init__(self, *transformations: BaseGraphTransformations):
self.transformations = list(transformations)
Parameters
----------
kg : KnowledgeGraph
The knowledge graph to be transformed.
Returns
-------
t.List[t.Coroutine]
A list of coroutines to be executed in parallel.
"""

async def apply_build_relationships(kg: KnowledgeGraph):
relationships = await self.transform(kg)
kg.relationships.extend(relationships)

class Sequences:
def __init__(self, *transformations: t.Union[BaseGraphTransformations, Parallel]):
self.transformations = list(transformations)
filtered = self.filter(kg)
return [apply_build_relationships(filtered)]
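To make the new hook concrete, here is a minimal sketch of a custom extractor built on the `Extractor` base class above. `WordCountExtractor` and the `page_content` property it reads are illustrative assumptions, not part of this commit:

```py
import typing as t

from ragas.experimental.testset.graph import Node
from ragas.experimental.testset.transforms.base import Extractor


class WordCountExtractor(Extractor):
    """Stores the word count of each node's text as a node property."""

    async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
        # the inherited generate_execution_plan() wraps this in apply_extract(),
        # which writes the (name, value) pair back into the node and skips
        # nodes that already carry the property
        text = node.get_property("page_content") or ""  # assumed property name
        return "word_count", len(text.split())
```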
112 changes: 112 additions & 0 deletions src/ragas/experimental/testset/transforms/engine.py
@@ -0,0 +1,112 @@
from __future__ import annotations

import asyncio
import logging
import typing as t
from dataclasses import dataclass

from ragas.executor import as_completed, is_event_loop_running, tqdm
from ragas.experimental.testset.graph import KnowledgeGraph
from ragas.experimental.testset.transforms.base import (
BaseGraphTransformations,
)
from ragas.run_config import RunConfig

logger = logging.getLogger(__name__)


class Parallel:
    def __init__(self, *transformations: BaseGraphTransformations):
        self.transformations = list(transformations)

    def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
        coroutines = []
        for transformation in self.transformations:
            coroutines.extend(transformation.generate_execution_plan(kg))
        return coroutines


async def run_coroutines(coroutines: t.List[t.Coroutine], desc: str, max_workers: int):
    """
    Run a list of coroutines in parallel.
    """
    for future in tqdm(
        await as_completed(coroutines, max_workers=max_workers),
        desc=desc,
        total=len(coroutines),
        # whether you want to keep the progress bar after completion
        leave=True,
    ):
        try:
            await future
        except Exception as e:
            logger.error(f"unable to apply transformation: {e}")


def get_desc(transform: BaseGraphTransformations | Parallel):
    if isinstance(transform, Parallel):
        transform_names = [t.__class__.__name__ for t in transform.transformations]
        return f"Applying [{', '.join(transform_names)}] transformations in parallel"
    else:
        return f"Applying {transform.__class__.__name__}"


@dataclass
class TransformerEngine:
    _nest_asyncio_applied: bool = False

    def _apply_nest_asyncio(self):
        if is_event_loop_running():
            # an event loop is running so call nested_asyncio to fix this
            try:
                import nest_asyncio
            except ImportError:
                raise ImportError(
                    "It seems like you're running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
                )

            if not self._nest_asyncio_applied:
                nest_asyncio.apply()
                self._nest_asyncio_applied = True

    def apply(
        self,
        transforms: t.List[BaseGraphTransformations] | Parallel,
        kg: KnowledgeGraph,
        run_config: RunConfig = RunConfig(),
    ) -> KnowledgeGraph:
        # apply nest_asyncio to fix the event loop issue in jupyter
        self._apply_nest_asyncio()

        # apply the transformations
        # if given a list, apply each transformation sequentially
        if isinstance(transforms, t.List):
            for transform in transforms:
                asyncio.run(
                    run_coroutines(
                        transform.generate_execution_plan(kg),
                        get_desc(transform),
                        run_config.max_workers,
                    )
                )
        # if Parallel, collect the coroutines inside it and run them all together
        elif isinstance(transforms, Parallel):
            asyncio.run(
                run_coroutines(
                    transforms.generate_execution_plan(kg),
                    get_desc(transforms),
                    run_config.max_workers,
                )
            )
        else:
            raise ValueError(
                f"Invalid transforms type: {type(transforms)}. Expects a list of BaseGraphTransformations or a Parallel instance."
            )

        return kg

    def rollback(
        self, transforms: t.List[BaseGraphTransformations], on: KnowledgeGraph
    ) -> KnowledgeGraph:
        # this will allow you to roll back the transformations
        raise NotImplementedError
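As a quick illustration (not part of the commit) of how `Parallel` composes work: it simply concatenates its members' execution plans, so the engine runs all of their per-node coroutines in a single pool. The sketch assumes an existing `KnowledgeGraph` named `kg` and reuses extractors from the commit message:

```py
from ragas.experimental.testset.transforms import Parallel
from ragas.experimental.testset.transforms.extractors import (
    KeyphrasesExtractor,
    SummaryExtractor,
)

# one pooled stage: every coroutine from both extractors runs concurrently
parallel = Parallel(SummaryExtractor(), KeyphrasesExtractor())
plan = parallel.generate_execution_plan(kg)  # kg: an existing KnowledgeGraph
print(len(plan))  # one coroutine per extractor per node that passes its filter
```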
19 changes: 19 additions & 0 deletions src/ragas/experimental/testset/transforms/extractors/__init__.py
@@ -0,0 +1,19 @@
from .embeddings import EmbeddingExtractor
from .llm_based import (
    HeadlinesExtractor,
    KeyphrasesExtractor,
    SummaryExtractor,
    TitleExtractor,
)
from .regex_based import emails_extractor, links_extractor, markdown_headings_extractor

__all__ = [
    "emails_extractor",
    "links_extractor",
    "markdown_headings_extractor",
    "SummaryExtractor",
    "KeyphrasesExtractor",
    "TitleExtractor",
    "HeadlinesExtractor",
    "EmbeddingExtractor",
]
@@ -157,4 +157,6 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
        if node_text is None:
            return self.property_name, None
        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
        if result is None:
            return self.property_name, None
        return self.property_name, result.headlines
@@ -0,0 +1,3 @@
from .headline import HeadlineSplitter

__all__ = ["HeadlineSplitter"]
