
Commit a18b514

Add async sleep statements and logging to record request time
1 parent f142c0f commit a18b514

3 files changed: +99 -11 lines changed

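A note on the approach (not part of the diff itself): the simulated latency is added with asyncio.sleep rather than time.sleep, so each handler yields the event loop while it waits and concurrent requests overlap instead of queueing behind one another. A minimal, self-contained sketch of that behaviour, using only the standard library:

import asyncio
import time


async def fake_request(latency_seconds: float) -> None:
    # asyncio.sleep yields to the event loop, so other coroutines run meanwhile.
    await asyncio.sleep(latency_seconds)


async def main() -> None:
    started = time.time()
    # Ten concurrent "requests", each sleeping 0.2 s, finish in roughly
    # 0.2 s total rather than 2 s, because the sleeps overlap.
    await asyncio.gather(*(fake_request(0.2) for _ in range(10)))
    print(f"elapsed: {time.time() - started:.2f}s")


asyncio.run(main())

Swapping asyncio.sleep for a blocking time.sleep in this sketch makes the ten calls take roughly 2 seconds instead of 0.2, which is presumably why the mock server uses the async variant.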
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

nemoguardrails/benchmark/mock_llm_server/api.py

Lines changed: 61 additions & 4 deletions
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 
+import asyncio
+import logging
 import time
 from typing import Annotated, Optional, Union
 
-from fastapi import Depends, FastAPI, HTTPException
+from fastapi import Depends, FastAPI, HTTPException, Request, Response
 
 from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config
 from nemoguardrails.benchmark.mock_llm_server.models import (
@@ -35,9 +37,28 @@
 from nemoguardrails.benchmark.mock_llm_server.response_data import (
     calculate_tokens,
     generate_id,
+    get_latency_seconds,
     get_response,
 )
 
+# Create a console logging handler
+log = logging.getLogger(__name__)
+log.setLevel(logging.INFO) # TODO Control this from the CLi args
+
+# Create a formatter to define the log message format
+formatter = logging.Formatter(
+    "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+
+# Create a console handler to print logs to the console
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO) # DEBUG and higher will go to the console
+console_handler.setFormatter(formatter)
+
+# Add console handler to logs
+log.addHandler(console_handler)
+
+
 ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)]
 
 
@@ -60,6 +81,24 @@ def _validate_request_model(
     )
 
 
+@app.middleware("http")
+async def log_http_duration(request: Request, call_next):
+    """
+    Middleware to log incoming requests and their responses.
+    """
+    request_time = time.time()
+    response = await call_next(request)
+    response_time = time.time()
+
+    duration_seconds = response_time - request_time
+    log.info(
+        "Request finished: %s, took %.3f seconds",
+        response.status_code,
+        duration_seconds,
+    )
+    return response
+
+
 @app.get("/")
 async def root(config: ModelConfigDep):
     """Root endpoint with basic server information."""
@@ -75,22 +114,30 @@ async def root(config: ModelConfigDep):
 @app.get("/v1/models", response_model=ModelsResponse)
 async def list_models(config: ModelConfigDep):
     """List available models."""
+    log.debug("/v1/models request")
+
     model = Model(
         id=config.model, object="model", created=int(time.time()), owned_by="system"
     )
-    return ModelsResponse(object="list", data=[model])
+    response = ModelsResponse(object="list", data=[model])
+    log.debug("/v1/models response: %s", response)
+    return response
 
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def chat_completions(
     request: ChatCompletionRequest, config: ModelConfigDep
 ) -> ChatCompletionResponse:
     """Create a chat completion."""
+
+    log.debug("/v1/chat/completions request: %s", request)
+
     # Validate model exists
     _validate_request_model(config, request)
 
     # Generate dummy response
     response_content = get_response(config)
+    response_latency_seconds = get_latency_seconds(config, seed=12345)
 
     # Calculate token usage
     prompt_text = " ".join([msg.content for msg in request.messages])
@@ -122,7 +169,8 @@ async def chat_completions(
             total_tokens=prompt_tokens + completion_tokens,
         ),
     )
-
+    await asyncio.sleep(response_latency_seconds)
+    log.debug("/v1/chat/completions response: %s", response)
     return response
 
 
@@ -132,6 +180,8 @@ async def completions(
 ) -> CompletionResponse:
     """Create a text completion."""
 
+    log.debug("/v1/completions request: %s", request)
+
     # Validate model exists
     _validate_request_model(config, request)
 
@@ -143,6 +193,7 @@
 
     # Generate dummy response
     response_text = get_response(config)
+    response_latency_seconds = get_latency_seconds(config, seed=12345)
 
     # Calculate token usage
     prompt_tokens = calculate_tokens(prompt_text)
@@ -171,10 +222,16 @@ async def completions(
             total_tokens=prompt_tokens + completion_tokens,
         ),
     )
+
+    await asyncio.sleep(response_latency_seconds)
+    log.debug("/v1/completions response: %s", response)
     return response
 
 
 @app.get("/health")
 async def health_check():
     """Health check endpoint."""
-    return {"status": "healthy", "timestamp": int(time.time())}
+    log.debug("/health request")
+    response = {"status": "healthy", "timestamp": int(time.time())}
+    log.debug("/health response: %s", response)
+    return response
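get_latency_seconds is imported from response_data above, but its implementation is not part of this commit; the handlers only show that it takes the model config plus a fixed seed=12345 and returns a number of seconds that is passed to asyncio.sleep. Purely as a hypothetical illustration of such a helper (the function body and the latency_min_seconds / latency_max_seconds config fields below are assumptions for illustration, not the repository's actual code):

import random


def get_latency_seconds(config, seed=None) -> float:
    """Hypothetical sketch: return a reproducible latency for a mock response.

    The config field names used here are assumptions, not the real
    AppModelConfig schema.
    """
    rng = random.Random(seed)
    min_s = getattr(config, "latency_min_seconds", 0.1)
    max_s = getattr(config, "latency_max_seconds", 0.5)
    return rng.uniform(min_s, max_s)

Because the handlers pass a constant seed, a helper along these lines would draw the same latency on every request, which suggests the intent is a deterministic, reproducible delay for benchmarking.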

nemoguardrails/benchmark/mock_llm_server/run_server.py

Lines changed: 24 additions & 7 deletions
@@ -21,12 +21,29 @@
 """
 
 import argparse
+import logging
 import sys
 
 import uvicorn
+from uvicorn.logging import AccessFormatter
 
 from nemoguardrails.benchmark.mock_llm_server.config import get_config, load_config
 
+# 1. Get a logger instance
+log = logging.getLogger(__name__)
+log.setLevel(logging.DEBUG) # Set the lowest level to capture all messages
+
+# Set up formatter and direct it to the console
+formatter = logging.Formatter(
+    "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG) # DEBUG and higher will go to the console
+console_handler.setFormatter(formatter)
+
+# Add the console handler for logging
+log.addHandler(console_handler)
+
 
 def main():
     parser = argparse.ArgumentParser(description="Run the Mock LLM Server")
@@ -64,11 +81,11 @@ def main():
     # Import the app after configuration is loaded. This caches the values in the app Dependencies
     from nemoguardrails.benchmark.mock_llm_server.api import app
 
-    print(f"Starting Mock LLM Server on {args.host}:{args.port}")
-    print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs")
-    print(f"Health check at: http://{args.host}:{args.port}/health")
-    print(f"Model configuration: {model_config}")
-    print("Press Ctrl+C to stop the server")
+    log.info(f"Starting Mock LLM Server on {args.host}:{args.port}")
+    log.info(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs")
+    log.info(f"Health check at: http://{args.host}:{args.port}/health")
+    log.info(f"Model configuration: {model_config}")
+    log.info("Press Ctrl+C to stop the server")
 
     try:
         uvicorn.run(
@@ -79,9 +96,9 @@ def main():
             log_level=args.log_level,
         )
     except KeyboardInterrupt:
-        print("\nServer stopped by user")
+        log.info("\nServer stopped by user")
     except Exception as e: # pylint: disable=broad-except
-        print(f"Error starting server: {e}")
+        log.error(f"Error starting server: {e}")
         sys.exit(1)
 
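One observation on the logging setup (a suggestion, not something this commit does): api.py and run_server.py each hand-build the same Formatter and StreamHandler for their module loggers. A sketch of an equivalent single-point setup using logging.basicConfig, assuming it is called once at process start:

import logging

# A possible consolidation (assumption, not the commit's approach): configure
# the root logger once; modules then just call logging.getLogger(__name__)
# without attaching their own handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

log = logging.getLogger(__name__)
log.info("Mock LLM Server logging configured")

Configuring handlers in one place also avoids duplicate log lines, which can appear when a module-level handler and a root-logger handler are both attached and records propagate to both.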
