Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ class EventNotifier:
def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
self.callbacks = callbacks

def add_callback(self, callback: EventCallbackProtocol) -> None:
self.callbacks.append(callback)

def remove_callback(self, callback: EventCallbackProtocol) -> None:
self.callbacks.remove(callback)

async def notify(self, event: Event) -> None:
await asyncio.gather(
*[c(event) for c in self.callbacks],
Expand Down Expand Up @@ -114,6 +120,17 @@ async def notify_pipeline_finished(
)
await self.notify(event)

async def notify_pipeline_failed(
self, run_id: str, message: Optional[str] = None
) -> None:
event = PipelineEvent(
event_type=EventType.PIPELINE_FAILED,
run_id=run_id,
message=message,
payload=None,
)
await self.notify(event)

async def notify_task_started(
self,
run_id: str,
Expand Down
11 changes: 3 additions & 8 deletions src/neo4j_graphrag/experimental/pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
PipelineMissingDependencyError,
PipelineStatusUpdateError,
)
from neo4j_graphrag.experimental.pipeline.notification import EventNotifier
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.types.orchestration import (
RunResult,
Expand All @@ -52,10 +51,10 @@ class Orchestrator:
(checking that all dependencies are met), and run them.
"""

def __init__(self, pipeline: Pipeline):
def __init__(self, pipeline: Pipeline, run_id: Optional[str] = None):
self.pipeline = pipeline
self.event_notifier = EventNotifier(pipeline.callbacks)
self.run_id = str(uuid.uuid4())
self.event_notifier = self.pipeline.event_notifier
self.run_id = run_id or str(uuid.uuid4())

async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
"""Get inputs and run a specific task. Once the task is done,
Expand Down Expand Up @@ -265,9 +264,5 @@ async def run(self, data: dict[str, Any]) -> None:
(node without any parent). Then the callback on_task_complete
will handle the task dependencies.
"""
await self.event_notifier.notify_pipeline_started(self.run_id, data)
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
await asyncio.gather(*tasks)
await self.event_notifier.notify_pipeline_finished(
self.run_id, await self.pipeline.get_final_results(self.run_id)
)
55 changes: 32 additions & 23 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

import uuid

from neo4j_graphrag.utils.logging import prettify

try:
Expand All @@ -39,8 +41,7 @@
from neo4j_graphrag.experimental.pipeline.notification import (
Event,
EventCallbackProtocol,
EventType,
PipelineEvent,
EventNotifier,
)
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
Expand Down Expand Up @@ -103,7 +104,7 @@ async def run(
res = await self.execute(context, inputs)
end_time = default_timer()
logger.debug(
f"TASK FINISHED {self.name} in {end_time - start_time} res={prettify(res)}"
f"TASK FINISHED {self.name} in {round(end_time - start_time, 2)}s res={prettify(res)}"
)
return res

Expand All @@ -124,7 +125,6 @@ def __init__(
) -> None:
super().__init__()
self.store = store or InMemoryStore()
self.callbacks = [callback] if callback else []
self.final_results = InMemoryStore()
self.is_validated = False
self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
Expand All @@ -139,6 +139,7 @@ def __init__(
}
"""
self.missing_inputs: dict[str, list[str]] = defaultdict()
self.event_notifier = EventNotifier([callback] if callback else [])

@classmethod
def from_template(
Expand Down Expand Up @@ -507,14 +508,13 @@ async def stream(
"""
# Create queue for events
event_queue: asyncio.Queue[Event] = asyncio.Queue()
run_id = None

async def event_stream(event: Event) -> None:
# Put event in queue for streaming
await event_queue.put(event)

# Add event streaming callback
self.callbacks.append(event_stream)
self.event_notifier.add_callback(event_stream)

event_queue_getter_task = None
try:
Expand Down Expand Up @@ -542,39 +542,48 @@ async def event_stream(event: Event) -> None:
# we are sure to get an Event here, since this is the only
# thing we put in the queue, but mypy still complains
event = event_future.result()
run_id = getattr(event, "run_id", None)
yield event # type: ignore

if exc := run_task.exception():
yield PipelineEvent(
event_type=EventType.PIPELINE_FAILED,
# run_id is null if pipeline fails before even starting
# ie during pipeline validation
run_id=run_id or "",
message=str(exc),
)
if raise_exception:
raise exc

finally:
# Restore original callback
self.callbacks.remove(event_stream)
self.event_notifier.remove_callback(event_stream)
if event_queue_getter_task and not event_queue_getter_task.done():
event_queue_getter_task.cancel()

async def run(self, data: dict[str, Any]) -> PipelineResult:
logger.debug("PIPELINE START")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self)
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)
run_id = str(uuid.uuid4())
logger.debug(f"PIPELINE START with {run_id=}")
try:
res = await self._run(run_id, data)
except Exception as e:
await self.event_notifier.notify_pipeline_failed(
run_id,
message=f"Pipeline failed with error {e}",
)
raise e
end_time = default_timer()
logger.debug(
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
f"PIPELINE FINISHED {run_id} in {round(end_time - start_time, 2)}s"
)
return PipelineResult(
return res

async def _run(self, run_id: str, data: dict[str, Any]) -> PipelineResult:
await self.event_notifier.notify_pipeline_started(run_id, data)
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self, run_id)
await orchestrator.run(data)
result = PipelineResult(
run_id=orchestrator.run_id,
result=await self.get_final_results(orchestrator.run_id),
)
await self.event_notifier.notify_pipeline_finished(
run_id,
await self.get_final_results(run_id),
)
return result
33 changes: 24 additions & 9 deletions tests/unit/experimental/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,23 @@ async def test_pipeline_event_notification() -> None:
previous_ts = actual_event.timestamp


@pytest.mark.asyncio
async def test_pipeline_event_notification_error_in_pipeline_run() -> None:
callback = AsyncMock(spec=EventCallbackProtocol)
pipe = Pipeline(callback=callback)
component_a = ComponentAdd()
component_b = ComponentAdd()
pipe.add_component(component_a, "a")
pipe.add_component(component_b, "b")
pipe.connect("a", "b", {"number1": "a.result"})

with pytest.raises(PipelineDefinitionError):
await pipe.run({"a": {"number1": 1, "number2": 2}})
assert len(callback.await_args_list) == 2
assert callback.await_args_list[0][0][0].event_type == EventType.PIPELINE_STARTED
assert callback.await_args_list[1][0][0].event_type == EventType.PIPELINE_FAILED


def test_event_model_no_warning(recwarn: Sized) -> None:
event = Event(
event_type=EventType.PIPELINE_STARTED,
Expand All @@ -503,7 +520,7 @@ async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
assert len(events) == 2
assert events[0].event_type == EventType.PIPELINE_STARTED
assert events[1].event_type == EventType.PIPELINE_FINISHED
assert len(pipe.callbacks) == 0
assert len(pipe.event_notifier.callbacks) == 0


@pytest.mark.asyncio
Expand All @@ -515,7 +532,7 @@ async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
events.append(e)
assert len(events) == 2
assert len(callback.call_args_list) == 2
assert len(pipe.callbacks) == 1
assert len(pipe.event_notifier.callbacks) == 1


@pytest.mark.asyncio
Expand All @@ -528,7 +545,7 @@ async def callback(event: Event) -> None:
async for e in pipe.stream({}):
events.append(e)
assert len(events) == 2
assert len(pipe.callbacks) == 1
assert len(pipe.event_notifier.callbacks) == 1


@pytest.mark.asyncio
Expand Down Expand Up @@ -557,11 +574,9 @@ async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
with pytest.raises(PipelineDefinitionError):
async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
events.append(e)
# validation happens before pipeline run actually starts
# but we have the PIPELINE_FAILED event
assert len(events) == 1
assert events[0].event_type == EventType.PIPELINE_FAILED
assert events[0].run_id == ""
assert len(events) == 2
assert events[0].event_type == EventType.PIPELINE_STARTED
assert events[1].event_type == EventType.PIPELINE_FAILED


@pytest.mark.asyncio
Expand Down Expand Up @@ -589,4 +604,4 @@ async def callback(event: Event) -> None:
async for e in pipe.stream({}):
events.append(e)
assert len(events) == 2
assert len(pipe.callbacks) == 1
assert len(pipe.event_notifier.callbacks) == 1