Add support for multiple models
christophebe committed Mar 3, 2024
1 parent a354dbb commit 912a2ef
Showing 8 changed files with 154 additions and 77 deletions.
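In short, the commit replaces the hard-coded OpenAI model union with an llm type that also covers Mistral models (plus claude and groq placeholders). A minimal sketch of how a Mistral model could now be passed through, assuming the PostGenerator / AutoPostPrompt API shown in the diffs below; the import paths, the generate() call, the topic field and the Post content field are assumptions, not confirmed by this diff:

import { PostGenerator } from 'src/post-generator'
import { AutoPostPrompt } from 'src/types'

// Any model returned by getLLMs() can now be passed as `model`.
const prompt: AutoPostPrompt = {
  model: 'mistral-large-latest', // previously only OpenAI model names were accepted
  language: 'english',
  topic: 'Multi-model blog post generation', // illustrative value
  temperature: 0.8,
  frequencyPenalty: 1,
  presencePenalty: 1,
  apiKey: process.env.MISTRAL_API_KEY // assumption: the same apiKey field carries the Mistral key
}

const generator = new PostGenerator(prompt)
const post = await generator.generate() // assumption: generate() is the public entry point
console.log(post.content)               // assumption: the Post type exposes a content field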
19 changes: 9 additions & 10 deletions src/bin/command/post-template.ts
@@ -2,16 +2,16 @@ import fs from 'fs'
import { Command } from 'commander'
import { marked as markdownToHTML } from 'marked'
import { PostTemplateGenerator } from '../../post-generator'
import { TemplatePost, TemplatePostPrompt } from 'src/types'
import { getFileExtension, isHTML, isMarkdown } from 'src/lib/template'
import { debug } from 'console'
import { DEFAULT_LLM, TemplatePostPrompt, getLLMs, llm } from 'src/types'
import { getFileExtension, isMarkdown } from 'src/lib/template'

type Options = {
debug: boolean
debugapi: boolean
apiKey: string
templateFile: string
model: 'gpt-4-turbo-preview' | 'gpt-4' | 'gpt-3.5-turbo'
promptFolder: string
model: llm
filename: string
temperature: number
frequencyPenalty: number
@@ -33,12 +33,12 @@ export function buildPostTemplateCommands (program: Command) {
}, {})
return { ...previous, ...obj }
}, {})
.option('-m, --model <model>', 'Set the LLM : "gpt-4-turbo-preview" | "gpt-4" | "gpt-3.5-turbo" (optional), gpt-4-turbo-preview by default')
.option('-m, --model <model>', `Set the LLM : ${getLLMs().join('| ')}`)

.option('-f, --filename <filename>', 'Set the post file name (optional)')
.option('-tt, --temperature <temperature>', 'Set the temperature (optional)')
.option('-fp, --frequencypenalty <frequencyPenalty>', 'Set the frequency penalty (optional)')
.option('-pp, --presencepenalty <presencePenalty>', 'Set the presence penalty (optional)')
.option('-lb, --logitbias <logitBias>', 'Set the logit bias (optional)')
.option('-d, --debug', 'Output extra debugging (optional)')
.option('-da, --debugapi', 'Debug the api calls (optional)')
.option('-k, --apiKey <key>', 'Set the OpenAI api key (optional, you can also set the OPENAI_API_KEY environment variable)')
@@ -57,6 +57,7 @@ async function generatePost (options: Options) {
const postPrompt: TemplatePostPrompt = {
...defaultPostPrompt,
...options

}

const postGenerator = new PostTemplateGenerator(postPrompt)
@@ -84,11 +85,9 @@ async function generatePost (options: Options) {

function buildDefaultPostPrompt () {
return {
model: 'gpt-4-turbo-preview',
model: DEFAULT_LLM,
temperature: 0.8,
frequencyPenalty: 0,
presencePenalty: 1,
logitBias: 0

presencePenalty: 1
}
}
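Because the CLI options are spread over these defaults, a run without --model falls back to DEFAULT_LLM, while any value from getLLMs() overrides it. A small sketch of that merge with a hypothetical buildPrompt helper (not part of the commit; the cast papers over fields a real TemplatePostPrompt may require):

import { DEFAULT_LLM, TemplatePostPrompt } from 'src/types'

// Hypothetical helper mirroring the `{ ...defaultPostPrompt, ...options }` pattern above.
function buildPrompt (options: Partial<TemplatePostPrompt>): TemplatePostPrompt {
  const defaults = {
    model: DEFAULT_LLM,
    temperature: 0.8,
    frequencyPenalty: 0,
    presencePenalty: 1
  }
  // The later spread wins: anything passed on the command line overrides the default.
  return { ...defaults, ...options } as TemplatePostPrompt
}

buildPrompt({})                                 // model === 'gpt-4-turbo-preview' (DEFAULT_LLM)
buildPrompt({ model: 'mistral-small-latest' })  // model === 'mistral-small-latest'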
14 changes: 7 additions & 7 deletions src/bin/command/post.ts
@@ -3,7 +3,7 @@ import { Command } from 'commander'
import { marked as markdownToHTML } from 'marked'
import { askQuestions } from '../question/questions'
import { PostGenerator } from '../../post-generator'
import { Post, AutoPostPrompt } from 'src/types'
import { Post, AutoPostPrompt, llm, DEFAULT_LLM, getLLMs } from 'src/types'

type Options = {
interactive: boolean
@@ -12,8 +12,9 @@ type Options = {
apiKey: string
templateFile: string
language: string
model: 'gpt-4-turbo-preview' | 'gpt-4' | 'gpt-3.5-turbo'
model: llm
filename: string
promptFolder: string
topic: string
country: string
generate: boolean // generate the audience and intent
@@ -29,7 +30,7 @@ export function buildPostCommands (program: Command) {
.description('Generate a post in interactive or automatic mode')
.option('-i, --interactive', 'Use interactive mode (CLI questions)')
.option('-l, --language <language>', 'Set the language (optional), english by default')
.option('-m, --model <model>', 'Set the LLM : "gpt-4-turbo-preview" | "gpt-4" | "gpt-3.5-turbo" (optional), gpt-4-turbo-preview by default')
.option('-m, --model <model>', `Set the LLM : ${getLLMs().join('| ')}`)
.option('-f, --filename <filename>', 'Set the post file name (optional)')
.option('-pf, --promptfolder <promptFolder>', 'Use custom prompt define in this folder (optional)')
.option('-tp, --topic <topic>', 'Set the post topic (optional)')
@@ -39,7 +40,6 @@
.option('-tt, --temperature <temperature>', 'Set the temperature (optional)')
.option('-fp, --frequencypenalty <frequencyPenalty>', 'Set the frequency penalty (optional)')
.option('-pp, --presencepenalty <presencePenalty>', 'Set the presence penalty (optional)')
.option('-lb, --logitbias <logitBias>', 'Set the logit bias (optional)')
.option('-d, --debug', 'Output extra debugging (optional)')
.option('-da, --debugapi', 'Debug the api calls (optional)')
.option('-k, --apiKey <key>', 'Set the OpenAI api key (optional, you can also set the OPENAI_API_KEY environment variable)')
@@ -98,13 +98,13 @@ function isInteractive (options : Options) {

function buildDefaultPostPrompt () : AutoPostPrompt {
return {
model: 'gpt-4-turbo-preview',
model: DEFAULT_LLM,
language: 'english',
withConclusion: true,
temperature: 0.8,
frequencyPenalty: 1,
presencePenalty: 1,
logitBias: 0
presencePenalty: 1

}
}

12 changes: 4 additions & 8 deletions src/bin/question/questions.ts
@@ -1,6 +1,7 @@
import inquirer from 'inquirer'
import inquirerPrompt from 'inquirer-autocomplete-prompt'
import inquirerFileTreeSelection from 'inquirer-file-tree-selection-prompt'
import { DEFAULT_LLM, getLLMs } from 'src/types'

inquirer.registerPrompt('autocomplete', inquirerPrompt)
inquirer.registerPrompt('file-tree-selection', inquirerFileTreeSelection)
@@ -11,7 +12,8 @@ const LANGUAGES = ['english', 'french', 'spanish', 'german', 'italian', 'russian',
'slovak', 'croatian', 'ukrainian', 'slovene', 'estonian', 'latvian', 'lithuanian',
'chinese', 'hindi', 'arabic', 'japanese']

const MODELS = ['gpt-4-turbo-preview', 'gpt-4', 'gpt-3.5-turbo']
const MODELS = getLLMs()

const questions = [
{
type: 'autocomplete',
@@ -25,7 +27,7 @@ const questions = [
name: 'model',
message: 'AI model ?',
choices: MODELS,
default: 'gpt-4-turbo-preview'
default: DEFAULT_LLM
},
{
type: 'input',
@@ -84,12 +86,6 @@ const questions = [
name: 'presencePenalty',
message: 'Presence Penalty (-2/2) ?',
default: 1
},
{
type: 'number',
name: 'logitBias',
message: 'Logit bias ?',
default: 0
}

]
45 changes: 45 additions & 0 deletions src/lib/llm.ts
@@ -0,0 +1,45 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models'
import { ChatMistralAI } from '@langchain/mistralai'
import { ChatOpenAI } from '@langchain/openai'
import { BasePostPrompt } from '../types'

export function buildLLM (postPrompt: BasePostPrompt, forJson: boolean = false): BaseChatModel {
switch (postPrompt.model) {
case 'gpt-4':
case 'gpt-4-turbo-preview':
return buildOpenAI(postPrompt, forJson)
case 'mistral-small-latest':
case 'mistral-medium-latest':
case 'mistral-large-latest':
return buildMistral(postPrompt, forJson)

default:
return buildOpenAI(postPrompt, forJson)
}
}

function buildOpenAI (postPrompt: BasePostPrompt, forJson: boolean = false) {
const llmParams = {
modelName: postPrompt.model.toString(),
temperature: postPrompt.temperature ?? 0.8,
frequencyPenalty: forJson ? 0 : postPrompt.frequencyPenalty ?? 1,
presencePenalty: forJson ? 0 : postPrompt.presencePenalty ?? 1,
verbose: postPrompt.debugapi,
openAIApiKey: postPrompt.apiKey

}
return new ChatOpenAI(llmParams)
}

function buildMistral (postPrompt: BasePostPrompt, forJson: boolean = false) {
const llmParams = {
modelName: postPrompt.model.toString(),
temperature: postPrompt.temperature ?? 0.8,
frequencyPenalty: forJson ? 0 : postPrompt.frequencyPenalty ?? 1,
presencePenalty: forJson ? 0 : postPrompt.presencePenalty ?? 1,
verbose: postPrompt.debugapi,
apiKey: postPrompt.apiKey

}
return new ChatMistralAI(llmParams)
}
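A short usage sketch for the new factory (not part of the commit): the invoke call is the standard LangChain chat-model API, and the prompt values are illustrative. Note that 'claude' and 'groq' currently fall through to the OpenAI client via the default branch of the switch.

import { buildLLM } from 'src/lib/llm'
import { BasePostPrompt } from 'src/types'

const postPrompt: BasePostPrompt = {
  model: 'mistral-large-latest',
  temperature: 0.7,
  apiKey: process.env.MISTRAL_API_KEY // assumption: the same apiKey field is reused for the Mistral key
}

// forJson = true zeroes the penalties, which helps the model produce well-formed JSON.
const llmJson = buildLLM(postPrompt, true)
const response = await llmJson.invoke('Return {"ok": true} as a JSON object')
console.log(response.content)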
4 changes: 2 additions & 2 deletions src/lib/parser.ts
@@ -7,7 +7,7 @@ import { isHTML, isMarkdown } from './template'
const HeadingSchema: z.ZodSchema<any> = z.object({
title: z.string(),
keywords: z.array(z.string()).optional(),
headings: z.array(z.lazy(() => PostOutlineSchema)).optional()
headings: z.array(z.lazy(() => HeadingSchema)).optional()
})

const PostOutlineSchema = z.object({
@@ -59,7 +59,7 @@ export class HTMLOutputParser extends BaseOutputParser<string> {

getFormatInstructions (): string {
return `
Your answer has to be only a HRML block.
Your answer has to be only a HTML block.
The block has to delimited by \`\`\`html (beginning of the block) and \`\`\` (end of the block)
`
}
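The HeadingSchema change at the top of this file fixes the recursion: a heading's children are now validated as headings rather than as full post outlines. A standalone sketch with hypothetical data:

import { z } from 'zod'

const HeadingSchema: z.ZodSchema<any> = z.object({
  title: z.string(),
  keywords: z.array(z.string()).optional(),
  headings: z.array(z.lazy(() => HeadingSchema)).optional() // self-reference instead of PostOutlineSchema
})

// Nested sub-headings of arbitrary depth now validate.
HeadingSchema.parse({
  title: 'Multi-model support',
  headings: [{ title: 'Mistral models', keywords: ['mistral'] }]
})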
84 changes: 42 additions & 42 deletions src/post-generator.ts
@@ -37,7 +37,8 @@ import {
} from './lib/prompt'

import { Template } from './lib/template'
import { log } from 'console'
import { buildLLM } from './lib/llm'
import { BaseChatModel } from '@langchain/core/language_models/chat_models'

dotenv.config()
const readFile = promisify(rd)
@@ -52,38 +53,27 @@ const PARSER_INSTRUCTIONS_TAG = '\n{formatInstructions}\n'
// - Conclusion (optional)
// -----------------------------------------------------------------------------------------
export class PostGenerator {
private llm_json: ChatOpenAI
private llm_content: ChatOpenAI
private llm_json: BaseChatModel
private llm_content: BaseChatModel
private memory : BufferMemory
private log
private promptFolder: string

public constructor (private postPrompt: AutoPostPrompt) {
this.log = createLogger(postPrompt.debug ? 'debug' : 'info')

if (this.postPrompt.promptFolder) {
if (this.postPrompt.promptFolder && postPrompt.promptFolder !== '') {
this.log.info('Use prompts from folder : ' + this.postPrompt.promptFolder)
}

this.promptFolder = postPrompt.promptFolder ?? path.join(__dirname, DEFAULT_PROMPT_FOLDER)

this.llm_content = new ChatOpenAI({
modelName: postPrompt.model,
temperature: postPrompt.temperature ?? 0.8,
frequencyPenalty: postPrompt.frequencyPenalty ?? 0,
presencePenalty: postPrompt.presencePenalty ?? 1,
verbose: postPrompt.debugapi,
openAIApiKey: postPrompt.apiKey
})
this.promptFolder = postPrompt.promptFolder && postPrompt.promptFolder !== ''
? postPrompt.promptFolder
: path.join(__dirname, DEFAULT_PROMPT_FOLDER)

this.llm_content = buildLLM(postPrompt) as BaseChatModel
// For the outline, we use a different setting without frequencyPenalty and presencePenalty
// in order to avoid some json format issue
this.llm_json = new ChatOpenAI({
modelName: postPrompt.model,
temperature: postPrompt.temperature ?? 0.8,
verbose: postPrompt.debugapi,
openAIApiKey: postPrompt.apiKey
})
this.llm_json = buildLLM(postPrompt, true)

this.memory = new BufferMemory({
returnMessages: true
@@ -226,7 +216,7 @@ export class PostGenerator {
{ input: outlineMessage },
{ output: this.postOutlineToMarkdown(outline) }
)
this.log.debug('OUTLINE :\n\n')
this.log.debug(' ----------------------OUTLINE ----------------------')
this.log.debug(JSON.stringify(outline, null, 2))

return outline
@@ -320,7 +310,17 @@ export class PostGenerator {
keywords: heading.keywords?.join(', ')
}

const content = await chain.invoke(inputVariables)
let content = await chain.invoke(inputVariables)

if (content === '' || content === null) {
this.log.warn(`🤷🏻‍♂️ No content generated for heading : ${heading.title} with the model : ${this.postPrompt.model}`)
content = `🤷🏻‍♂️ No content generated with the model: ${this.postPrompt.model}`
}

this.log.debug(' ---------------------- HEADING : ' + heading.title + '----------------------')
this.log.debug(content)
this.log.debug(' ---------------------- HEADING END ----------------------')

this.memory.saveContext(
{ input: `Write a content for the heading : ${heading.title}` },
{ output: content }
@@ -367,7 +367,17 @@
language: this.postPrompt.language
}

const content = await chain.invoke(inputVariables)
let content = await chain.invoke(inputVariables)

if (content === '' || content === null) {
this.log.warn(`🤷🏻‍♂️ No content generated with the model : ${this.postPrompt.model}`)
content = `🤷🏻‍♂️ No content generated with the model: ${this.postPrompt.model}`
}

this.log.debug(' ---------------------- CONTENT ----------------------')
this.log.debug(content)
this.log.debug(' ---------------------- CONTENT END ----------------------')

this.memory.saveContext(
{ input: memoryInput },
{ output: content }
@@ -432,8 +442,8 @@
// A template is a file containing prompts that will be replaced by the content
// -----------------------------------------------------------------------------------------
export class PostTemplateGenerator {
private llm_content: ChatOpenAI
private llm_json: ChatOpenAI
private llm_content: BaseChatModel
private llm_json: BaseChatModel
private memory: BufferMemory
private log
private promptFolder: string
@@ -445,24 +455,14 @@
this.log.info('Use prompts from folder : ' + this.postPrompt.promptFolder)
}

this.promptFolder = postPrompt.promptFolder ?? path.join(__dirname, DEFAULT_PROMPT_FOLDER)

this.llm_content = new ChatOpenAI({
modelName: postPrompt.model,
temperature: postPrompt.temperature ?? 0.8,
frequencyPenalty: postPrompt.frequencyPenalty ?? 0,
presencePenalty: postPrompt.presencePenalty ?? 0,
verbose: postPrompt.debugapi,
openAIApiKey: postPrompt.apiKey
})

this.llm_json = new ChatOpenAI({
modelName: postPrompt.model,
temperature: postPrompt.temperature ?? 0.8,
verbose: postPrompt.debugapi,
openAIApiKey: postPrompt.apiKey
this.promptFolder = postPrompt.promptFolder && postPrompt.promptFolder !== ''
? postPrompt.promptFolder
: path.join(__dirname, DEFAULT_PROMPT_FOLDER)

})
this.llm_content = buildLLM(postPrompt)
// For the outline, we use a different setting without frequencyPenalty and presencePenalty
// in order to avoid some json format issue
this.llm_json = buildLLM(postPrompt, true)

this.memory = new BufferMemory({
returnMessages: true
26 changes: 24 additions & 2 deletions src/types.ts
@@ -1,9 +1,31 @@
export type llm =
'gpt-4' |
'gpt-4-turbo-preview' |
'mistral-small-latest' |
'mistral-medium-latest' |
'mistral-large-latest' |
'claude' |
'groq'

export const DEFAULT_LLM : llm = 'gpt-4-turbo-preview'

export function getLLMs (): llm[] {
return [
'gpt-4',
'gpt-4-turbo-preview',
'mistral-small-latest',
'mistral-medium-latest',
'mistral-large-latest',
'claude',
'groq'
]
}

export type BasePostPrompt = {
model: 'gpt-4-turbo-preview' | 'gpt-4' | 'gpt-3.5-turbo'
model: llm
temperature?: number
frequencyPenalty?: number
presencePenalty?: number
logitBias?: number
debug?: boolean
debugapi?: boolean
apiKey?: string
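A small sketch of how a CLI-supplied string could be validated against the new union; the isLLM and resolveModel helpers are hypothetical, not part of this commit:

import { DEFAULT_LLM, getLLMs, llm } from 'src/types'

// Hypothetical type guard: narrow an arbitrary string (e.g. the --model value) to the llm union.
function isLLM (value: string): value is llm {
  return (getLLMs() as string[]).includes(value)
}

function resolveModel (value?: string): llm {
  return value && isLLM(value) ? value : DEFAULT_LLM
}

resolveModel('mistral-medium-latest') // -> 'mistral-medium-latest'
resolveModel('unknown-model')         // -> 'gpt-4-turbo-preview' (DEFAULT_LLM)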