Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unit Tests for chat.ts #79

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/pages/api/__tests__/chat.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// src/pages/api/__tests__/chat.test.ts
import { handler } from '../chat.ts';
import { NextResponse } from 'next/server';
import { OpenAIStream } from '@/utils/server';

describe('handler', () => {
it('returns correct response for valid chat request', async () => {
const mockReq = {
json: jest.fn().mockResolvedValue({
model: { id: 'gpt-3' },
messages: [],
key: 'test-key',
prompt: 'test-prompt',
temperature: 0.5,
course_name: 'test-course',
stream: false,
}),
};

const mockOpenAIStream = jest.spyOn(OpenAIStream, 'default');
mockOpenAIStream.mockResolvedValue('test-response');

const response = await handler(mockReq as any);

expect(response).toBeInstanceOf(NextResponse);
expect(response.body).toEqual('test-response');
});

it('throws error for invalid chat request', async () => {
const mockReq = {
json: jest.fn().mockResolvedValue({}),
};

await expect(handler(mockReq as any)).rejects.toThrow();
});

it('correctly parses and encodes message from chat request', async () => {
const mockReq = {
json: jest.fn().mockResolvedValue({
model: { id: 'gpt-3' },
messages: [{ content: 'test-message' }],
key: 'test-key',
prompt: 'test-prompt',
temperature: 0.5,
course_name: 'test-course',
stream: false,
}),
};

const mockOpenAIStream = jest.spyOn(OpenAIStream, 'default');
mockOpenAIStream.mockResolvedValue('test-response');

await handler(mockReq as any);

expect(mockOpenAIStream).toHaveBeenCalledWith(
{ id: 'gpt-3' },
'test-promptOnly answer if it\'s related to the course materials.',
0.5,
'test-key',
[{ content: 'test-message' }],
false,
);
});

// Additional test cases to cover all branches of the code...
});
58 changes: 54 additions & 4 deletions src/pages/api/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,26 @@ const handler = async (req: Request): Promise<NextResponse> => {
console.log("Model's token limit", token_limit)

let promptToSend = prompt
if (!promptToSend) {
export function determinePrompt(prompt: string): string {
if (!prompt) {
return DEFAULT_SYSTEM_PROMPT;
}
return prompt;
}

if (!promptToSend) {
promptToSend = DEFAULT_SYSTEM_PROMPT
}

let temperatureToUse = temperature
if (temperatureToUse == null) {
export function determineTemperature(temperature: number | null): number {
if (temperature == null) {
return DEFAULT_TEMPERATURE;
}
return temperature;
}

if (temperatureToUse == null) {
temperatureToUse = DEFAULT_TEMPERATURE
}

Expand Down Expand Up @@ -104,7 +118,35 @@ const handler = async (req: Request): Promise<NextResponse> => {
}
}

// Take most recent N messages that will fit in the context window
export function prepareMessagesToSend(messages: OpenAIChatMessage[], encoding: Tiktoken, token_limit: number): OpenAIChatMessage[] {
let tokenCount = 0;
let messagesToSend: OpenAIChatMessage[] = [];

for (let i = messages.length - 1; i >= 0; i--) {
const message = messages[i];
if (message) {
let content: string;
if (typeof message.content === 'string') {
content = message.content;
} else {
content = message.content.map(c => c.text || '').join(' ');
}
const tokens = encoding.encode(content);

if (tokenCount + tokens.length + 1000 > token_limit) {
break;
}
tokenCount += tokens.length;
messagesToSend = [
{ role: message.role, content: message.content as Content[] },
...messagesToSend,
];
}
}
return messagesToSend;
}

// Take most recent N messages that will fit in the context window
const prompt_tokens = encoding.encode(promptToSend)

let tokenCount = prompt_tokens.length
Expand Down Expand Up @@ -148,7 +190,15 @@ const handler = async (req: Request): Promise<NextResponse> => {
messagesToSend,
stream
)
export function constructResponse(apiStream: string, stream: boolean): NextResponse {
if (stream) {
return new NextResponse(apiStream);
} else {
return new NextResponse(JSON.stringify(apiStream));
}
}

if (stream) {
return new NextResponse(apiStream)
} else {
return new NextResponse(JSON.stringify(apiStream))
Expand Down Expand Up @@ -177,4 +227,4 @@ const handler = async (req: Request): Promise<NextResponse> => {
}
}

export default handler
export { handler }
Loading