sglang or vllm api interface #51

Open
devops724 opened this issue Feb 27, 2025 · 1 comment
🚀 The feature, motivation and pitch

Here is how to launch an API server using sglang, from its quick start page:
Launch A Server
from sglang.test.test_utils import is_in_ci
from sglang.utils import wait_for_server, print_highlight, terminate_process

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

# This is equivalent to running the following command in your terminal:
# python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0
server_process, port = launch_server_cmd(
    """
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0
"""
)

wait_for_server(f"http://localhost:{port}")
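
Once the server is up, sglang exposes an OpenAI-compatible endpoint, so it can be queried with the standard openai client. A minimal sketch (assuming the openai Python package is installed and the model name matches what was loaded):

from openai import OpenAI

# Point the client at the local sglang server (port comes from launch_server_cmd above).
client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=64,
)
print(response.choices[0].message.content)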
I want to run olmocr over an sglang server.
Is there such a feature?

Alternatives

No response

Additional context

No response

devops724 (Author) commented Feb 27, 2025

For everyone who may want to run this model behind an OpenAI-compatible server,
here is code showing how you can run it:

import torch
import base64
import uvicorn
import json
import argparse
import os
from io import BytesIO
from typing import List, Dict, Any, Optional, Union
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
import time
import uuid

# Parse command line arguments
parser = argparse.ArgumentParser(description="OLMoCR API Server")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")
parser.add_argument("--model-id", type=str, default="allenai/olmOCR-7B-0225-preview", help="Model ID to load")
args = parser.parse_args()

# Create FastAPI app
app = FastAPI(
    title="OLMoCR API",
    description="OpenAPI-compatible REST API for OLMoCR OCR and document understanding model",
    version="0.1.0",
)

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define models for API
class ImageUrl(BaseModel):
    url: str

class ContentItem(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    role: str
    content: Union[str, List[ContentItem]]

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    max_tokens: int = 300
    temperature: float = 0.8
    stream: bool = False
    
class CompletionChoice(BaseModel):
    index: int
    message: Message
    finish_reason: str = "stop"

class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class ChatCompletionResponse(BaseModel):
    id: str = Field(..., example="chatcmpl-123456789")
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: CompletionUsage

# Load the model and processor
print(f"Loading model: {args.model_id}")
device = torch.device(args.device)
model = Qwen2VLForConditionalGeneration.from_pretrained(args.model_id, torch_dtype=torch.bfloat16).eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model.to(device)
print(f"Model loaded on {device}")

# Define routes
@app.get("/")
async def root():
    return {"message": "OLMoCR OpenAPI Server", "status": "running"}

@app.get("/v1/models")
async def get_models():
    """List available models"""
    return {
        "object": "list",
        "data": [
            {
                "id": args.model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "allenai"
            }
        ]
    }

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    """Process a chat completion request with image inputs"""
    try:
        # Only support the last message for now
        last_message = request.messages[-1]
        
        # Process the content
        if isinstance(last_message.content, str):
            # If content is just a string, treat it as a text prompt
            prompt = last_message.content
            image_data = None
        else:
            # Process the content items
            prompt = None
            image_data = None
            
            for item in last_message.content:
                if item.type == "text":
                    prompt = item.text
                elif item.type == "image_url":
                    image_url = item.image_url.url
                    
                    # Handle base64 encoded images
                    if image_url.startswith("data:image"):
                        # Extract the base64 part
                        image_data = image_url.split(",")[1]
                    # Handle PDF URLs (very basic detection)
                    elif image_url.endswith(".pdf"):
                        # Download the PDF and convert to image
                        import urllib.request
                        import tempfile
                        
                        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
                            urllib.request.urlretrieve(image_url, tmp_file.name)
                            # Convert PDF to image
                            image_data = render_pdf_to_base64png(tmp_file.name, 1, target_longest_image_dim=1024)
                            
                            # Build anchor text if no prompt given
                            if not prompt:
                                anchor_text = get_anchor_text(tmp_file.name, 1, pdf_engine="pdfreport", target_length=4000)
                                prompt = build_finetuning_prompt(anchor_text)
                            
                            # Clean up
                            os.unlink(tmp_file.name)
                    else:
                        # Download the image
                        import urllib.request
                        
                        with urllib.request.urlopen(image_url) as response:
                            image_data = base64.b64encode(response.read()).decode("utf-8")
        
        if not image_data:
            raise HTTPException(status_code=400, detail="No image data provided")
        
        if not prompt:
            prompt = "Extract and describe the text content from this image."
        
        # Process the input for the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}},
                ],
            }
        ]
        
        # Apply chat template
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        
        # Process the image
        main_image = Image.open(BytesIO(base64.b64decode(image_data)))
        inputs = processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )
        inputs = {key: value.to(device) for (key, value) in inputs.items()}
        
        # Count input tokens
        input_token_count = inputs["input_ids"].shape[1]
        
        # Generate output
        start_time = time.time()
        output = model.generate(
            **inputs,
            temperature=request.temperature,
            max_new_tokens=request.max_tokens,
            num_return_sequences=1,
            do_sample=True,
        )
        end_time = time.time()
        
        # Decode output
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_length:]
        text_output = processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )[0]
        
        # Count output tokens
        output_token_count = new_tokens.shape[1]
        
        # Create response
        completion_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
        response = ChatCompletionResponse(
            id=completion_id,
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                CompletionChoice(
                    index=0,
                    message=Message(
                        role="assistant",
                        content=text_output
                    ),
                    finish_reason="stop"
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=input_token_count,
                completion_tokens=output_token_count,
                total_tokens=input_token_count + output_token_count
            )
        )
        
        print(f"Request processed in {end_time - start_time:.2f}s, generated {output_token_count} tokens")
        return response
    
    except Exception as e:
        print(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    print(f"Starting OLMoCR API server on {args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)
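
A minimal client sketch for the server above (it assumes the server is running on localhost:8000; the PDF path, page number, and prompt are placeholders):

import requests
from olmocr.data.renderpdf import render_pdf_to_base64png

# Render the first page of a local PDF to a base64 PNG (path is a placeholder).
image_b64 = render_pdf_to_base64png("sample.pdf", 1, target_longest_image_dim=1024)

payload = {
    "model": "allenai/olmOCR-7B-0225-preview",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Extract the text from this page."},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
            ],
        }
    ],
    "max_tokens": 512,
}

# POST to the OpenAI-style chat completions route defined above.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])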
