Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions chatkit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,9 @@ async def _process_events(
with agents_sdk_user_agent_override():
async for event in stream():
if isinstance(event, ThreadItemAddedEvent):
pending_items[event.item.id] = event.item
# Stash an isolated copy in case we need to persist unfinished items
# on cancellation; downstream handlers keep using the original event.item.
pending_items[event.item.id] = event.item.model_copy(deep=True)

match event:
case ThreadItemDoneEvent():
Expand Down Expand Up @@ -779,27 +781,25 @@ def _apply_assistant_message_update(
| AssistantMessageContentPartAnnotationAdded
| AssistantMessageContentPartDone,
) -> AssistantMessageItem:
updated = item.model_copy(deep=True)

# Pad the content list so the requested content_index exists before we write into it.
# (Streaming updates can arrive for an index that hasn’t been created yet)
while len(updated.content) <= update.content_index:
updated.content.append(AssistantMessageContent(text="", annotations=[]))
while len(item.content) <= update.content_index:
item.content.append(AssistantMessageContent(text="", annotations=[]))

match update:
case AssistantMessageContentPartAdded():
updated.content[update.content_index] = update.content
item.content[update.content_index] = update.content
case AssistantMessageContentPartTextDelta():
updated.content[update.content_index].text += update.delta
item.content[update.content_index].text += update.delta
case AssistantMessageContentPartAnnotationAdded():
annotations = updated.content[update.content_index].annotations
annotations = item.content[update.content_index].annotations
if update.annotation_index <= len(annotations):
annotations.insert(update.annotation_index, update.annotation)
else:
annotations.append(update.annotation)
case AssistantMessageContentPartDone():
updated.content[update.content_index] = update.content
return updated
item.content[update.content_index] = update.content
return item

def _update_pending_items(
self,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "openai-chatkit"
version = "1.5.0"
version = "1.5.1"
description = "A ChatKit backend SDK."
readme = "README.md"
requires-python = ">=3.10"
Expand Down
63 changes: 63 additions & 0 deletions tests/test_chatkit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
AttachmentsDeleteReq,
AttachmentUploadDescriptor,
ClientToolCallItem,
CustomTask,
FeedbackKind,
FileAttachment,
ImageAttachment,
Expand Down Expand Up @@ -79,6 +80,9 @@
UserMessageTextContent,
WidgetItem,
WidgetRootUpdated,
Workflow,
WorkflowItem,
WorkflowTaskAdded,
)
from chatkit.widgets import Card, Text
from tests._types import RequestContext
Expand Down Expand Up @@ -354,6 +358,65 @@ def generate_item_id(
)


async def test_workflow_task_not_duplicated_on_done_event():
"""
Regression test to make sure pending item updates do not modify the
origin item that was streamed.
"""

async def responder(
thread: ThreadMetadata, input: UserMessageItem | None, context: Any
) -> AsyncIterator[ThreadStreamEvent]:
workflow_item = WorkflowItem(
id="workflow-item",
created_at=datetime.now(),
thread_id=thread.id,
workflow=Workflow(type="custom", tasks=[]),
)

yield ThreadItemAddedEvent(item=workflow_item)

task = CustomTask(title="foo", content="bar")
yield ThreadItemUpdatedEvent(
item_id=workflow_item.id,
update=WorkflowTaskAdded(task=task, task_index=0),
)

workflow_item.workflow.tasks.append(task)
yield ThreadItemDoneEvent(item=workflow_item)

with make_server(responder) as server:
events = await server.process_streaming(
ThreadsCreateReq(
params=ThreadCreateParams(
input=UserMessageInput(
content=[UserMessageTextContent(text="Hello")],
attachments=[],
inference_options=InferenceOptions(),
)
)
)
)

thread = next(e.thread for e in events if e.type == "thread.created")
workflow_done_event = next(
e
for e in events
if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
)
workflow_done_item = cast(WorkflowItem, workflow_done_event.item)
assert len(workflow_done_item.workflow.tasks) == 1
assert workflow_done_item.workflow.tasks[0].title == "foo"

stored = await server.store.load_item(
thread.id, workflow_done_item.id, DEFAULT_CONTEXT
)
assert isinstance(stored, WorkflowItem)
stored_workflow = cast(WorkflowItem, stored)
assert len(stored_workflow.workflow.tasks) == 1
assert stored_workflow.workflow.tasks[0].title == "foo"


async def test_flows_context_to_responder():
responder_context = None
add_feedback_context = None
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.