41 changes: 27 additions & 14 deletions mlx_vlm/server.py
@@ -256,6 +256,7 @@ class OpenAIResponse(BaseModel):

 class BaseStreamEvent(BaseModel):
     type: str
+    sequence_number: int


 class ContentPartOutputText(BaseModel):
@@ -344,6 +345,16 @@ class ResponseCompletedEvent(BaseStreamEvent):
 ]


+class SequenceNumber:
+    def __init__(self):
+        self._sequence_number = 0
+
+    def get_and_increment(self) -> int:
+        num = self._sequence_number
+        self._sequence_number += 1
+        return num
+
+
 # OpenAI endpoint
 @app.post("/responses")
 async def openai_endpoint(request: Request):
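Aside, not part of the diff: the helper above is a plain per-stream counter. Each get_and_increment() call returns the current index and advances it, so the first event of a stream is numbered 0; this mirrors the sequence_number field that OpenAI's Responses streaming events carry. A minimal sketch of the behavior, assuming the class exactly as added in this hunk:

```python
counter = SequenceNumber()
assert counter.get_and_increment() == 0  # stamped on response.created
assert counter.get_and_increment() == 1  # stamped on response.in_progress
assert counter.get_and_increment() == 2  # stamped on response.output_item.added
```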
@@ -486,6 +497,8 @@ def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, mode
     message_id = f"msg_{uuid.uuid4().hex}"

     if openai_request.stream:
+        sequence_number = SequenceNumber()
+
         # Streaming response
         async def stream_generator():
             token_iterator = None
@@ -511,10 +524,10 @@ async def stream_generator():
                )

                # Send response.created event (to match the openai pipeline)
-                yield f"event: response.created\ndata: {ResponseCreatedEvent(type='response.created', response=base_response).model_dump_json()}\n\n"
+                yield f"event: response.created\ndata: {ResponseCreatedEvent(type='response.created', response=base_response, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Send response.in_progress event (to match the openai pipeline)
-                yield f"event: response.in_progress\ndata: {ResponseInProgressEvent(type='response.in_progress', response=base_response).model_dump_json()}\n\n"
+                yield f"event: response.in_progress\ndata: {ResponseInProgressEvent(type='response.in_progress', response=base_response, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Send response.output_item.added event (to match the openai pipeline)
                message_item = MessageItem(
@@ -524,13 +537,13 @@ async def stream_generator():
                    role="assistant",
                    content=[],
                )
-                yield f"event: response.output_item.added\ndata: {ResponseOutputItemAddedEvent(type='response.output_item.added', output_index=0, item=message_item).model_dump_json()}\n\n"
+                yield f"event: response.output_item.added\ndata: {ResponseOutputItemAddedEvent(type='response.output_item.added', output_index=0, item=message_item, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Send response.content_part.added event
                content_part = ContentPartOutputText(
                    type="output_text", text="", annotations=[]
                )
-                yield f"event: response.content_part.added\ndata: {ResponseContentPartAddedEvent(type='response.content_part.added', item_id=message_id, output_index=0, content_index=0, part=content_part).model_dump_json()}\n\n"
+                yield f"event: response.content_part.added\ndata: {ResponseContentPartAddedEvent(type='response.content_part.added', item_id=message_id, output_index=0, content_index=0, part=content_part, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Stream text deltas
                token_iterator = stream_generate(
@@ -558,17 +571,17 @@ async def stream_generator():
                    }

                    # Send response.output_text.delta event
-                    yield f"event: response.output_text.delta\ndata: {ResponseOutputTextDeltaEvent(type='response.output_text.delta', item_id=message_id, output_index=0, content_index=0, delta=delta).model_dump_json()}\n\n"
+                    yield f"event: response.output_text.delta\ndata: {ResponseOutputTextDeltaEvent(type='response.output_text.delta', item_id=message_id, output_index=0, content_index=0, delta=delta, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"
                    await asyncio.sleep(0.01)

                # Send response.output_text.done event (to match the openai pipeline)
-                yield f"event: response.output_text.done\ndata: {ResponseOutputTextDoneEvent(type='response.output_text.done', item_id=message_id, output_index=0, content_index=0, text=full_text).model_dump_json()}\n\n"
+                yield f"event: response.output_text.done\ndata: {ResponseOutputTextDoneEvent(type='response.output_text.done', item_id=message_id, output_index=0, content_index=0, text=full_text, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Send response.content_part.done event (to match the openai pipeline)
                final_content_part = ContentPartOutputText(
                    type="output_text", text=full_text, annotations=[]
                )
-                yield f"event: response.content_part.done\ndata: {ResponseContentPartDoneEvent(type='response.content_part.done', item_id=message_id, output_index=0, content_index=0, part=final_content_part).model_dump_json()}\n\n"
+                yield f"event: response.content_part.done\ndata: {ResponseContentPartDoneEvent(type='response.content_part.done', item_id=message_id, output_index=0, content_index=0, part=final_content_part, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Send response.output_item.done event (to match the openai pipeline)
                final_message_item = MessageItem(
@@ -578,22 +591,22 @@ async def stream_generator():
                    role="assistant",
                    content=[final_content_part],
                )
-                yield f"event: response.output_item.done\ndata: {ResponseOutputItemDoneEvent(type='response.output_item.done', output_index=0, item=final_message_item).model_dump_json()}\n\n"
+                yield f"event: response.output_item.done\ndata: {ResponseOutputItemDoneEvent(type='response.output_item.done', output_index=0, item=final_message_item, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

                # Send response.completed event (to match the openai pipeline)
                completed_response = base_response.model_copy(
                    update={
                        "status": "completed",
                        "output": [final_message_item],
-                        "usage": {
-                            "input_tokens": usage_stats["input_tokens"],
-                            "output_tokens": usage_stats["output_tokens"],
-                            "total_tokens": usage_stats["input_tokens"]
+                        "usage": OpenAIUsage(
Author (inline review comment on the line above): This fixes the following Pydantic serializer warning:

/Users/dwohlfahrt/workspace/github/dwohlfahrt/mlx-testbed/.venv/lib/python3.13/site-packages/pydantic/main.py:519: UserWarning: Pydantic serializer warnings:
  PydanticSerializationUnexpectedValue(Expected `OpenAIUsage` - serialized value may not be as expected [input_value={'input_tokens': 33, 'out... 13, 'total_tokens': 46}, input_type=dict])
  return self.__pydantic_serializer__.to_json(
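
A minimal repro of that warning, for reference (hypothetical Usage and Reply model names standing in for OpenAIUsage and the response model; the token counts are taken from the warning above). Pydantic's model_copy(update=...) bypasses validation, so a plain dict gets stored where a nested model is expected, and the serializer complains at dump time:

```python
from typing import Optional

from pydantic import BaseModel


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int
    total_tokens: int


class Reply(BaseModel):
    status: str
    usage: Optional[Usage] = None


reply = Reply(status="in_progress")

# update= values bypass validation, so the dict is stored as-is; serializing
# then emits PydanticSerializationUnexpectedValue (expected `Usage`, got dict).
bad = reply.model_copy(
    update={"usage": {"input_tokens": 33, "output_tokens": 13, "total_tokens": 46}}
)
bad.model_dump_json()  # UserWarning: Pydantic serializer warnings

# Constructing the nested model explicitly, as this diff now does, is clean.
good = reply.model_copy(
    update={"usage": Usage(input_tokens=33, output_tokens=13, total_tokens=46)}
)
good.model_dump_json()  # no warning
```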

+                            input_tokens=usage_stats["input_tokens"],
+                            output_tokens=usage_stats["output_tokens"],
+                            total_tokens=usage_stats["input_tokens"]
                             + usage_stats["output_tokens"],
-                        },
+                        ),
                     }
                 )
yield f"event: response.completed\ndata: {ResponseCompletedEvent(type='response.completed', response=completed_response).model_dump_json()}\n\n"
yield f"event: response.completed\ndata: {ResponseCompletedEvent(type='response.completed', response=completed_response, sequence_number=sequence_number.get_and_increment()).model_dump_json()}\n\n"

except Exception as e:
print(f"Error during stream generation: {e}")
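
With these changes, every server-sent event in a response stream carries a monotonically increasing sequence_number starting at 0, with one counter per request. A hypothetical client-side check of that ordering (not part of this PR; the URL, port, model name, and request fields are illustrative assumptions):

```python
import json

import requests

resp = requests.post(
    "http://localhost:8000/responses",
    json={
        "model": "mlx-community/Qwen2.5-VL-3B-Instruct-4bit",  # assumed model id
        "input": "Describe this image.",
        "stream": True,
    },
    stream=True,
)

last = -1
for line in resp.iter_lines(decode_unicode=True):
    # SSE frames look like "event: <name>" / "data: <json>" / blank separator;
    # only the data lines carry the serialized event payload.
    if not line or not line.startswith("data: "):
        continue
    event = json.loads(line[len("data: "):])
    seq = event["sequence_number"]
    assert seq == last + 1, f"out-of-order event: {seq} after {last}"
    last = seq
```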