Skip to content

Commit

Permalink
clean up, refactor and fix turn based conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmacarthy committed May 5, 2024
1 parent 985e53a commit c99893d
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 153 deletions.
69 changes: 35 additions & 34 deletions src/common/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export const EXTENSION_NAME = '@ext:rjmacarthy.twinny'
export const ASSISTANT = 'assistant'
export const USER = 'user'
export const TWINNY = '🤖 twinny'
export const SYSTEM = 'system'
export const YOU = '👤 You'
export const EMPTY_MESAGE = 'Sorry, I don’t understand. Please try again.'
export const MODEL_ERROR = 'Sorry, something went wrong...'
Expand Down Expand Up @@ -59,55 +60,55 @@ export const EVENT_NAME = {
}

export const TWINNY_COMMAND_NAME = {
enable: 'twinny.enable',
addTests: 'twinny.addTests',
addTypes: 'twinny.addTypes',
conversationHistory: 'twinny.conversationHistory',
disable: 'twinny.disable',
enable: 'twinny.enable',
explain: 'twinny.explain',
addTypes: 'twinny.addTypes',
refactor: 'twinny.refactor',
focusSidebar: 'twinny.sidebar.focus',
generateDocs: 'twinny.generateDocs',
addTests: 'twinny.addTests',
templateCompletion: 'twinny.templateCompletion',
stopGeneration: 'twinny.stopGeneration',
templates: 'twinny.templates',
getGitCommitMessage: 'twinny.getGitCommitMessage',
hideBackButton: 'twinny.hideBackButton',
manageProviders: 'twinny.manageProviders',
conversationHistory: 'twinny.conversationHistory',
manageTemplates: 'twinny.manageTemplates',
hideBackButton: 'twinny.hideBackButton',
newChat: 'twinny.newChat',
openChat: 'twinny.openChat',
settings: 'twinny.settings',
refactor: 'twinny.refactor',
sendTerminalText: 'twinny.sendTerminalText',
getGitCommitMessage: 'twinny.getGitCommitMessage',
newChat: 'twinny.newChat',
focusSidebar: 'twinny.sidebar.focus'
settings: 'twinny.settings',
stopGeneration: 'twinny.stopGeneration',
templateCompletion: 'twinny.templateCompletion',
templates: 'twinny.templates'
}

export const CONVERSATION_EVENT_NAME = {
saveConversation: 'twinny.save-conversation',
getConversations: 'twinny.get-conversations',
setActiveConversation: 'twinny.set-active-conversation',
getActiveConversation: 'twinny.get-active-conversation',
getConversations: 'twinny.get-conversations',
removeConversation: 'twinny.remove-conversation',
saveConversation: 'twinny.save-conversation',
saveLastConversation: 'twinny.save-last-conversation',
removeConversation: 'twinny.remove-conversation'
setActiveConversation: 'twinny.set-active-conversation'
}

export const PROVIDER_EVENT_NAME = {
addProvider: 'twinny.add-provider',
copyProvider: 'twinny.copy-provider',
focusProviderTab: 'twinny.focus-provider-tab',
getActiveChatProvider: 'twinny.get-active-provider',
getActiveFimProvider: 'twinny.get-active-fim-provider',
getAllProviders: 'twinny.get-providers',
removeProvider: 'twinny.remove-provider',
resetProvidersToDefaults: 'twinny.reset-providers-to-defaults',
setActiveChatProvider: 'twinny.set-active-chat-provider',
setActiveFimProvider: 'twinny.set-active-fim-provider',
updateProvider: 'twinny.update-provider',
focusProviderTab: 'twinny.focus-provider-tab',
copyProvider: 'twinny.copy-provider',
resetProvidersToDefaults: 'twinny.reset-providers-to-defaults'
updateProvider: 'twinny.update-provider'
}

export const CONVERSATION_STORAGE_KEY = 'twinny.conversations'
export const ACTIVE_CONVERSATION_STORAGE_KEY = 'twinny.active-conversation'
export const ACTIVE_CHAT_PROVIDER_STORAGE_KEY = 'twinny.active-chat-provider'
export const ACTIVE_CONVERSATION_STORAGE_KEY = 'twinny.active-conversation'
export const ACTIVE_FIM_PROVIDER_STORAGE_KEY = 'twinny.active-fim-provider'
export const CONVERSATION_STORAGE_KEY = 'twinny.conversations'
export const INFERENCE_PROVIDERS_STORAGE_KEY = 'twinny.inference-providers'

export const WORKSPACE_STORAGE_KEY = {
Expand All @@ -120,35 +121,35 @@ export const WORKSPACE_STORAGE_KEY = {
}

export const EXTENSION_SETTING_KEY = {
fimModelName: 'fimModelName',
chatModelName: 'chatModelName',
apiProvider: 'apiProvider',
apiProviderFim: 'apiProviderFim'
apiProviderFim: 'apiProviderFim',
chatModelName: 'chatModelName',
fimModelName: 'fimModelName'
}

export const EXTENSION_CONTEXT_NAME = {
twinnyConversationHistory: 'twinnyConversationHistory',
twinnyGeneratingText: 'twinnyGeneratingText',
twinnyManageTemplates: 'twinnyManageTemplates',
twinnyManageProviders: 'twinnyManageProviders',
twinnyConversationHistory: 'twinnyConversationHistory'
twinnyManageTemplates: 'twinnyManageTemplates'
}

export const WEBUI_TABS = {
chat: 'chat',
templates: 'templates',
history: 'history',
providers: 'providers',
history: 'history'
templates: 'templates'
}

export const FIM_TEMPLATE_FORMAT = {
automatic: 'automatic',
codegemma: 'codegemma',
codellama: 'codellama',
custom: 'custom-template',
deepseek: 'deepseek',
llama: 'llama',
stableCode: 'stable-code',
starcoder: 'starcoder',
codegemma: 'codegemma',
custom: 'custom-template'
starcoder: 'starcoder'
}

export const STOP_LLAMA = ['<EOT>']
Expand Down Expand Up @@ -286,5 +287,5 @@ export const MULTI_LINE_REACT = [
'jsx_element',
'jsx_element',
'jsx_opening_element',
'jsx_self_closing_element',
'jsx_self_closing_element'
]
116 changes: 43 additions & 73 deletions src/extension/chat-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ import {
EXTENSION_CONTEXT_NAME,
EVENT_NAME,
WEBUI_TABS,
USER,
ACTIVE_CHAT_PROVIDER_STORAGE_KEY
ACTIVE_CHAT_PROVIDER_STORAGE_KEY,
SYSTEM,
USER
} from '../common/constants'
import {
StreamResponse,
StreamBodyBase,
ServerMessage,
TemplateData,
ChatTemplateData,
Message,
StreamRequestOptions
} from '../common/types'
Expand Down Expand Up @@ -73,7 +73,7 @@ export class ChatService {
return provider
}

private buildStreamRequest(prompt: string, messages?: Message[] | Message[]) {
private buildStreamRequest(messages?: Message[] | Message[]) {
const provider = this.getProvider()

if (!provider) return
Expand All @@ -90,7 +90,7 @@ export class ChatService {
}
}

const requestBody = createStreamRequestBody(provider.provider, prompt, {
const requestBody = createStreamRequestBody(provider.provider, {
model: provider.modelName,
numPredictChat: this._numPredictChat,
temperature: this._temperature,
Expand Down Expand Up @@ -192,61 +192,6 @@ export class ChatService {
} as ServerMessage)
}

private buildMesageRoleContent = async (
messages: Message[],
language?: CodeLanguageDetails
): Promise<Message[]> => {
const editor = window.activeTextEditor
const selection = editor?.selection
const selectionContext = editor?.document.getText(selection) || ''
const systemMessage = {
role: 'system',
content: await this._templateProvider?.readSystemMessageTemplate(
this._promptTemplate
)
}

if (messages.length > 0 && (language || selectionContext)) {
const lastMessage = messages[messages.length - 1]

const detailsToAppend = []

if (language?.langName) {
detailsToAppend.push(`Language: ${language.langName}`)
}

if (selectionContext) {
detailsToAppend.push(`Selection: ${selectionContext}`)
}

const detailsString = detailsToAppend.length
? `\n\n${detailsToAppend.join(': ')}`
: ''

const updatedLastMessage = {
...lastMessage,
content: `${lastMessage.content}${detailsString}`
}

messages[messages.length - 1] = updatedLastMessage
}

return [systemMessage, ...messages]
}

private buildChatPrompt = async (messages: Message[]) => {
const editor = window.activeTextEditor
const selection = editor?.selection
const selectionContext = editor?.document.getText(selection) || ''
const prompt =
await this._templateProvider?.renderTemplate<ChatTemplateData>('chat', {
code: selectionContext || '',
messages,
role: USER
})
return prompt || ''
}

private buildTemplatePrompt = async (
template: string,
language: CodeLanguageDetails,
Expand Down Expand Up @@ -307,9 +252,30 @@ export class ChatService {
public async streamChatCompletion(messages: Message[]) {
this._completion = ''
this.sendEditorLanguage()
const messageRoleContent = await this.buildMesageRoleContent(messages)
const prompt = await this.buildChatPrompt(messages)
const request = this.buildStreamRequest(prompt, messageRoleContent)
const editor = window.activeTextEditor
const selection = editor?.selection
const selectionContext = editor?.document.getText(selection)

const systemMessage = {
role: SYSTEM,
content: await this._templateProvider?.readSystemMessageTemplate(
this._promptTemplate
)
}

const conversation = [
systemMessage,
...messages,
]

if (selectionContext) {
conversation.push({
role: USER,
content: `Use this code as a context for the next response: ${selectionContext}`
})
}

const request = this.buildStreamRequest(conversation)
if (!request) return
const { requestBody, requestOptions } = request
return this.streamResponse({ requestBody, requestOptions })
Expand Down Expand Up @@ -347,16 +313,20 @@ export class ChatService {
} as ServerMessage)
}

const messageRoleContent = await this.buildMesageRoleContent(
[
{
content: prompt,
role: 'user'
}
],
language
)
const request = this.buildStreamRequest(prompt, messageRoleContent)
const systemMessage = {
role: SYSTEM,
content: await this._templateProvider?.readSystemMessageTemplate(
this._promptTemplate
)
}

const request = this.buildStreamRequest([
systemMessage,
{
role: USER,
content: prompt
}
])
if (!request) return
const { requestBody, requestOptions } = request
return this.streamResponse({ requestBody, requestOptions, onEnd })
Expand Down
2 changes: 1 addition & 1 deletion src/extension/conversation-history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ export class ConversationHistory {
}
}

const requestBody = createStreamRequestBody(provider.provider, '', {
const requestBody = createStreamRequestBody(provider.provider, {
model: provider.modelName,
numPredictChat: 100,
temperature: this._temperature,
Expand Down
9 changes: 0 additions & 9 deletions src/extension/provider-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import {

export function createStreamRequestBody(
provider: string,
prompt: string,
options: {
temperature: number
numPredictChat: number
Expand All @@ -23,7 +22,6 @@ export function createStreamRequestBody(
case ApiProviders.OpenWebUI:
return {
model: options.model,
prompt,
stream: true,
messages: options.messages,
keep_alive: options.keepAlive,
Expand All @@ -32,13 +30,6 @@ export function createStreamRequestBody(
num_predict: options.numPredictChat
}
}
case ApiProviders.LlamaCpp:
return {
prompt,
stream: true,
temperature: options.temperature,
n_predict: options.numPredictChat
}
case ApiProviders.LiteLLM:
default:
return {
Expand Down
4 changes: 2 additions & 2 deletions src/extension/template-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as Handlebars from 'handlebars'
import * as path from 'path'
import { DefaultTemplate } from '../common/types'
import { defaultTemplates } from './templates'
import { DEFAULT_TEMPLATE_NAMES } from '../common/constants'
import { DEFAULT_TEMPLATE_NAMES, SYSTEM } from '../common/constants'

export class TemplateProvider {
private _basePath: string
Expand Down Expand Up @@ -103,7 +103,7 @@ export class TemplateProvider {
}

private filterSystemTemplates = (filterName: string) => {
return filterName !== 'chat' && filterName.includes('system') === false
return filterName !== 'chat' && filterName.includes(SYSTEM) === false
}

public listTemplates(): string[] {
Expand Down
Loading

0 comments on commit c99893d

Please sign in to comment.