From 912a2ef267ae024d1109e39a8a3648432b11628e Mon Sep 17 00:00:00 2001
From: Christophe
Date: Sun, 3 Mar 2024 12:00:51 +0100
Subject: [PATCH] Add support for multiple models

---
 src/bin/command/post-template.ts | 19 ++++----
 src/bin/command/post.ts          | 14 +++---
 src/bin/question/questions.ts    | 12 ++---
 src/lib/llm.ts                   | 45 +++++++++++++++++
 src/lib/parser.ts                |  4 +-
 src/post-generator.ts            | 84 ++++++++++++++++----------------
 src/types.ts                     | 26 +++++++++-
 tests/test-api.spec.ts           | 27 +++++++---
 8 files changed, 154 insertions(+), 77 deletions(-)
 create mode 100644 src/lib/llm.ts

diff --git a/src/bin/command/post-template.ts b/src/bin/command/post-template.ts
index b290e39..a8ef5b0 100644
--- a/src/bin/command/post-template.ts
+++ b/src/bin/command/post-template.ts
@@ -2,16 +2,16 @@ import fs from 'fs'
 import { Command } from 'commander'
 import { marked as markdownToHTML } from 'marked'
 import { PostTemplateGenerator } from '../../post-generator'
-import { TemplatePost, TemplatePostPrompt } from 'src/types'
-import { getFileExtension, isHTML, isMarkdown } from 'src/lib/template'
-import { debug } from 'console'
+import { DEFAULT_LLM, TemplatePostPrompt, getLLMs, llm } from 'src/types'
+import { getFileExtension, isMarkdown } from 'src/lib/template'
 
 type Options = {
   debug: boolean
   debugapi: boolean
   apiKey: string
   templateFile: string
-  model: 'gpt-4-turbo-preview' | 'gpt-4' | 'gpt-3.5-turbo'
+  promptFolder: string
+  model: llm
   filename: string
   temperature: number
   frequencyPenalty: number
@@ -33,12 +33,12 @@ export function buildPostTemplateCommands (program: Command) {
       }, {})
       return { ...previous, ...obj }
     }, {})
-    .option('-m, --model <model>', 'Set the LLM : "gpt-4-turbo-preview" | "gpt-4" | "gpt-3.5-turbo" (optional), gpt-4-turbo-preview by default')
+    .option('-m, --model <model>', `Set the LLM : ${getLLMs().join('| ')}`)
+    .option('-f, --filename <filename>', 'Set the post file name (optional)')
     .option('-tt, --temperature <temperature>', 'Set the temperature (optional)')
     .option('-fp, --frequencypenalty <frequencyPenalty>', 'Set the frequency penalty (optional)')
     .option('-pp, --presencepenalty <presencePenalty>', 'Set the presence penalty (optional)')
-    .option('-lb, --logitbias <logitBias>', 'Set the logit bias (optional)')
     .option('-d, --debug', 'Output extra debugging (optional)')
     .option('-da, --debugapi', 'Debug the api calls (optional)')
     .option('-k, --apiKey <key>', 'Set the OpenAI api key (optional, you can also set the OPENAI_API_KEY environment variable)')
@@ -57,6 +57,7 @@ async function generatePost (options: Options) {
   const postPrompt: TemplatePostPrompt = {
     ...defaultPostPrompt,
     ...options
+
   }
 
   const postGenerator = new PostTemplateGenerator(postPrompt)
@@ -84,11 +85,9 @@ async function generatePost (options: Options) {
 
 function buildDefaultPostPrompt () {
   return {
-    model: 'gpt-4-turbo-preview',
+    model: DEFAULT_LLM,
     temperature: 0.8,
     frequencyPenalty: 0,
-    presencePenalty: 1,
-    logitBias: 0
-
+    presencePenalty: 1
   }
 }
diff --git a/src/bin/command/post.ts b/src/bin/command/post.ts
index 69d5cd7..8f183b4 100644
--- a/src/bin/command/post.ts
+++ b/src/bin/command/post.ts
@@ -3,7 +3,7 @@ import { Command } from 'commander'
 import { marked as markdownToHTML } from 'marked'
 import { askQuestions } from '../question/questions'
 import { PostGenerator } from '../../post-generator'
-import { Post, AutoPostPrompt } from 'src/types'
+import { Post, AutoPostPrompt, llm, DEFAULT_LLM, getLLMs } from 'src/types'
 
 type Options = {
   interactive: boolean
   debug: boolean
   debugapi: boolean
   apiKey: string
   templateFile: string
   language: string
-  model: 'gpt-4-turbo-preview' | 'gpt-4' | 'gpt-3.5-turbo'
+  model: llm
   filename: string
+  promptFolder: string
   topic: string
   country: string
   generate: boolean // generate the audience and intent
@@ -29,7 +30,7 @@ export function buildPostCommands (program: Command) {
     .description('Generate a post in interactive or automatic mode')
     .option('-i, --interactive', 'Use interactive mode (CLI questions)')
     .option('-l, --language <language>', 'Set the language (optional), english by default')
-    .option('-m, --model <model>', 'Set the LLM : "gpt-4-turbo-preview" | "gpt-4" | "gpt-3.5-turbo" (optional), gpt-4-turbo-preview by default')
+    .option('-m, --model <model>', `Set the LLM : ${getLLMs().join('| ')}`)
     .option('-f, --filename <filename>', 'Set the post file name (optional)')
     .option('-pf, --promptfolder <promptFolder>', 'Use custom prompt define in this folder (optional)')
     .option('-tp, --topic <topic>', 'Set the post topic (optional)')
@@ -39,7 +40,6 @@ export function buildPostCommands (program: Command) {
     .option('-tt, --temperature <temperature>', 'Set the temperature (optional)')
     .option('-fp, --frequencypenalty <frequencyPenalty>', 'Set the frequency penalty (optional)')
     .option('-pp, --presencepenalty <presencePenalty>', 'Set the presence penalty (optional)')
-    .option('-lb, --logitbias <logitBias>', 'Set the logit bias (optional)')
     .option('-d, --debug', 'Output extra debugging (optional)')
     .option('-da, --debugapi', 'Debug the api calls (optional)')
     .option('-k, --apiKey <key>', 'Set the OpenAI api key (optional, you can also set the OPENAI_API_KEY environment variable)')
@@ -98,13 +98,13 @@ function isInteractive (options : Options) {
 
 function buildDefaultPostPrompt () : AutoPostPrompt {
   return {
-    model: 'gpt-4-turbo-preview',
+    model: DEFAULT_LLM,
     language: 'english',
     withConclusion: true,
     temperature: 0.8,
     frequencyPenalty: 1,
-    presencePenalty: 1,
-    logitBias: 0
+    presencePenalty: 1
+
   }
 }
 
diff --git a/src/bin/question/questions.ts b/src/bin/question/questions.ts
index 0a48b47..7dc39eb 100644
--- a/src/bin/question/questions.ts
+++ b/src/bin/question/questions.ts
@@ -1,6 +1,7 @@
 import inquirer from 'inquirer'
 import inquirerPrompt from 'inquirer-autocomplete-prompt'
 import inquirerFileTreeSelection from 'inquirer-file-tree-selection-prompt'
+import { DEFAULT_LLM, getLLMs } from 'src/types'
 
 inquirer.registerPrompt('autocomplete', inquirerPrompt)
 inquirer.registerPrompt('file-tree-selection', inquirerFileTreeSelection)
@@ -11,7 +12,8 @@ const LANGUAGES = ['english', 'french', 'spanish', 'german', 'italian', 'russian',
   'slovak', 'croatian', 'ukrainian', 'slovene', 'estonian', 'latvian', 'lithuanian', 'chinese', 'hindi', 'arabic', 'japanese']
 
-const MODELS = ['gpt-4-turbo-preview', 'gpt-4', 'gpt-3.5-turbo']
+const MODELS = getLLMs()
+
 
 const questions = [
   {
     type: 'autocomplete',
@@ -25,7 +27,7 @@ const questions = [
     name: 'model',
     message: 'AI model ?',
     choices: MODELS,
-    default: 'gpt-4-turbo-preview'
+    default: DEFAULT_LLM
   },
   {
     type: 'input',
@@ -84,12 +86,6 @@ const questions = [
     name: 'presencePenalty',
     message: 'Presence Penalty (-2/2) ?',
     default: 1
-  },
-  {
-    type: 'number',
-    name: 'logitBias',
-    message: 'Logit bias ?',
-    default: 0
   }
 ]
 
diff --git a/src/lib/llm.ts b/src/lib/llm.ts
new file mode 100644
index 0000000..5bebbd9
--- /dev/null
+++ b/src/lib/llm.ts
@@ -0,0 +1,45 @@
+import { BaseChatModel } from '@langchain/core/language_models/chat_models'
+import { ChatMistralAI } from '@langchain/mistralai'
+import { ChatOpenAI } from '@langchain/openai'
+import { BasePostPrompt } from '../types'
+
+export function buildLLM (postPrompt: BasePostPrompt, forJson: boolean = false): BaseChatModel {
+  switch (postPrompt.model) {
+    case 'gpt-4':
+    case 'gpt-4-turbo-preview':
+      return buildOpenAI(postPrompt, forJson)
+    case 'mistral-small-latest':
+    case 'mistral-medium-latest':
+    case 'mistral-large-latest':
+      return buildMistral(postPrompt, forJson)
+
+    default:
+      return buildOpenAI(postPrompt, forJson)
+  }
+}
+
+function buildOpenAI (postPrompt: BasePostPrompt, forJson: boolean = false) {
+  const llmParams = {
+    modelName: postPrompt.model.toString(),
+    temperature: postPrompt.temperature ?? 0.8,
+    frequencyPenalty: forJson ? 0 : postPrompt.frequencyPenalty ?? 1,
+    presencePenalty: forJson ? 0 : postPrompt.presencePenalty ?? 1,
+    verbose: postPrompt.debugapi,
+    openAIApiKey: postPrompt.apiKey
+
+  }
+  return new ChatOpenAI(llmParams)
+}
+
+function buildMistral (postPrompt: BasePostPrompt, forJson: boolean = false) {
+  const llmParams = {
+    modelName: postPrompt.model.toString(),
+    temperature: postPrompt.temperature ?? 0.8,
+    frequencyPenalty: forJson ? 0 : postPrompt.frequencyPenalty ?? 1,
+    presencePenalty: forJson ? 0 : postPrompt.presencePenalty ?? 1,
+    verbose: postPrompt.debugapi,
+    apiKey: postPrompt.apiKey
+
+  }
+  return new ChatMistralAI(llmParams)
+}
diff --git a/src/lib/parser.ts b/src/lib/parser.ts
index c4845ad..4e55a4e 100644
--- a/src/lib/parser.ts
+++ b/src/lib/parser.ts
@@ -7,7 +7,7 @@ import { isHTML, isMarkdown } from './template'
 const HeadingSchema: z.ZodSchema = z.object({
   title: z.string(),
   keywords: z.array(z.string()).optional(),
-  headings: z.array(z.lazy(() => PostOutlineSchema)).optional()
+  headings: z.array(z.lazy(() => HeadingSchema)).optional()
 })
 
 const PostOutlineSchema = z.object({
@@ -59,7 +59,7 @@ export class HTMLOutputParser extends BaseOutputParser {
   getFormatInstructions (): string {
     return `
-    Your answer has to be only a HRML block.
+    Your answer has to be only a HTML block.
     The block has to delimited by \`\`\`html (beginning of the block) and \`\`\` (end of the block)
     `
   }
 }
diff --git a/src/post-generator.ts b/src/post-generator.ts
index 0fcb9d1..178488b 100644
--- a/src/post-generator.ts
+++ b/src/post-generator.ts
@@ -37,7 +37,8 @@ import {
 } from './lib/prompt'
 import { Template } from './lib/template'
-import { log } from 'console'
+import { buildLLM } from './lib/llm'
+import { BaseChatModel } from '@langchain/core/language_models/chat_models'
 
 dotenv.config()
 const readFile = promisify(rd)
 
@@ -52,8 +53,8 @@ const PARSER_INSTRUCTIONS_TAG = '\n{formatInstructions}\n'
 // - Conclusion (optional)
 // -----------------------------------------------------------------------------------------
 export class PostGenerator {
-  private llm_json: ChatOpenAI
-  private llm_content: ChatOpenAI
+  private llm_json: BaseChatModel
+  private llm_content: BaseChatModel
   private memory : BufferMemory
   private log
   private promptFolder: string
@@ -61,29 +62,18 @@ export class PostGenerator {
   public constructor (private postPrompt: AutoPostPrompt) {
     this.log = createLogger(postPrompt.debug ? 'debug' : 'info')
 
-    if (this.postPrompt.promptFolder) {
+    if (this.postPrompt.promptFolder && postPrompt.promptFolder !== '') {
       this.log.info('Use prompts from folder : ' + this.postPrompt.promptFolder)
     }
 
-    this.promptFolder = postPrompt.promptFolder ?? path.join(__dirname, DEFAULT_PROMPT_FOLDER)
-
-    this.llm_content = new ChatOpenAI({
-      modelName: postPrompt.model,
-      temperature: postPrompt.temperature ?? 0.8,
-      frequencyPenalty: postPrompt.frequencyPenalty ?? 0,
-      presencePenalty: postPrompt.presencePenalty ?? 1,
-      verbose: postPrompt.debugapi,
-      openAIApiKey: postPrompt.apiKey
-    })
+    this.promptFolder = postPrompt.promptFolder && postPrompt.promptFolder !== ''
+      ? postPrompt.promptFolder
+      : path.join(__dirname, DEFAULT_PROMPT_FOLDER)
 
+    this.llm_content = buildLLM(postPrompt) as BaseChatModel
     // For the outline, we use a different setting without frequencyPenalty and presencePenalty
     // in order to avoid some json format issue
-    this.llm_json = new ChatOpenAI({
-      modelName: postPrompt.model,
-      temperature: postPrompt.temperature ?? 0.8,
-      verbose: postPrompt.debugapi,
-      openAIApiKey: postPrompt.apiKey
-    })
+    this.llm_json = buildLLM(postPrompt, true)
 
     this.memory = new BufferMemory({
       returnMessages: true
@@ -226,7 +216,7 @@ export class PostGenerator {
       { input: outlineMessage },
       { output: this.postOutlineToMarkdown(outline) }
     )
-    this.log.debug('OUTLINE :\n\n')
+    this.log.debug(' ----------------------OUTLINE ----------------------')
     this.log.debug(JSON.stringify(outline, null, 2))
 
     return outline
@@ -320,7 +310,17 @@ export class PostGenerator {
       keywords: heading.keywords?.join(', ')
     }
-    const content = await chain.invoke(inputVariables)
+    let content = await chain.invoke(inputVariables)
+
+    if (content === '' || content === null) {
+      this.log.warn(`🤷🏻‍♂️ No content generated for heading : ${heading.title} with the model : ${this.postPrompt.model}`)
+      content = `🤷🏻‍♂️ No content generated with the model: ${this.postPrompt.model}`
+    }
+
+    this.log.debug(' ---------------------- HEADING : ' + heading.title + '----------------------')
+    this.log.debug(content)
+    this.log.debug(' ---------------------- HEADING END ----------------------')
+
     this.memory.saveContext(
       { input: `Write a content for the heading : ${heading.title}` },
       { output: content }
     )
@@ -367,7 +367,17 @@ export class PostGenerator {
       language: this.postPrompt.language
     }
-    const content = await chain.invoke(inputVariables)
+    let content = await chain.invoke(inputVariables)
+
+    if (content === '' || content === null) {
+      this.log.warn(`🤷🏻‍♂️ No content generated with the model : ${this.postPrompt.model}`)
+      content = `🤷🏻‍♂️ No content generated with the model: ${this.postPrompt.model}`
+    }
+
+    this.log.debug(' ---------------------- CONTENT ----------------------')
+    this.log.debug(content)
+    this.log.debug(' ---------------------- CONTENT END ----------------------')
+
     this.memory.saveContext(
       { input: memoryInput },
       { output: content }
     )
@@ -432,8 +442,8 @@ export class PostGenerator {
 // A template is a file containing prompts that will be replaced by the content
 // -----------------------------------------------------------------------------------------
 export class PostTemplateGenerator {
-  private llm_content: ChatOpenAI
-  private llm_json: ChatOpenAI
+  private llm_content: BaseChatModel
+  private llm_json: BaseChatModel
   private memory: BufferMemory
   private log
   private promptFolder: string
@@ -445,24 +455,14 @@ export class PostTemplateGenerator {
       this.log.info('Use prompts from folder : ' + this.postPrompt.promptFolder)
     }
 
-    this.promptFolder = postPrompt.promptFolder ?? path.join(__dirname, DEFAULT_PROMPT_FOLDER)
-
-    this.llm_content = new ChatOpenAI({
-      modelName: postPrompt.model,
-      temperature: postPrompt.temperature ?? 0.8,
-      frequencyPenalty: postPrompt.frequencyPenalty ?? 0,
-      presencePenalty: postPrompt.presencePenalty ?? 0,
-      verbose: postPrompt.debugapi,
-      openAIApiKey: postPrompt.apiKey
-    })
-
-    this.llm_json = new ChatOpenAI({
-      modelName: postPrompt.model,
-      temperature: postPrompt.temperature ?? 0.8,
-      verbose: postPrompt.debugapi,
-      openAIApiKey: postPrompt.apiKey
+    this.promptFolder = postPrompt.promptFolder && postPrompt.promptFolder !== ''
+      ? postPrompt.promptFolder
+      : path.join(__dirname, DEFAULT_PROMPT_FOLDER)
 
-    })
+    this.llm_content = buildLLM(postPrompt)
+    // For the outline, we use a different setting without frequencyPenalty and presencePenalty
+    // in order to avoid some json format issue
+    this.llm_json = buildLLM(postPrompt, true)
 
     this.memory = new BufferMemory({
       returnMessages: true
diff --git a/src/types.ts b/src/types.ts
index fbf1812..0ef8a1e 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -1,9 +1,31 @@
+export type llm =
+  'gpt-4' |
+  'gpt-4-turbo-preview' |
+  'mistral-small-latest' |
+  'mistral-medium-latest' |
+  'mistral-large-latest' |
+  'claude' |
+  'groq'
+
+export const DEFAULT_LLM : llm = 'gpt-4-turbo-preview'
+
+export function getLLMs (): llm[] {
+  return [
+    'gpt-4',
+    'gpt-4-turbo-preview',
+    'mistral-small-latest',
+    'mistral-medium-latest',
+    'mistral-large-latest',
+    'claude',
+    'groq'
+  ]
+}
+
 export type BasePostPrompt = {
-  model: 'gpt-4-turbo-preview' | 'gpt-4' | 'gpt-3.5-turbo'
+  model: llm
   temperature?: number
   frequencyPenalty?: number
   presencePenalty?: number
-  logitBias?: number
   debug?: boolean
   debugapi?: boolean
   apiKey?: string
diff --git a/tests/test-api.spec.ts b/tests/test-api.spec.ts
index b039f44..2fefbcb 100644
--- a/tests/test-api.spec.ts
+++ b/tests/test-api.spec.ts
@@ -1,9 +1,8 @@
 import { AutoPostPrompt, TemplatePostPrompt } from '../src/types'
 import { PostGenerator, PostTemplateGenerator } from '../src/post-generator'
-import { log } from 'console'
 
 describe('API', () => {
-  it.skip('generates a post in automatic mode', async () => {
+  it.skip('generates a post in automatic mode - OPEN AI', async () => {
     try {
       const postPrompt: AutoPostPrompt = {
         language: 'french',
@@ -14,9 +13,26 @@ describe('API', () => {
       const postGenerator = new PostGenerator(postPrompt)
       const post = await postGenerator.generate()
       expect(post).not.toBeNull()
-      log(post)
+      console.log(post)
     } catch (e) {
-      log(e)
+      console.log(e)
+    }
+  }, 300000)
+  it('generates a post in automatic mode - MISTRAL', async () => {
+    try {
+      const postPrompt: AutoPostPrompt = {
+        language: 'french',
+        model: 'mistral-large-latest',
+        topic: 'Comment devenir digital nomad ?',
+        promptFolder: './prompts',
+        debug: true
+      }
+      const postGenerator = new PostGenerator(postPrompt)
+      const post = await postGenerator.generate()
+      expect(post).not.toBeNull()
+      console.log(post)
+    } catch (e) {
+      console.log(e)
     }
   }, 300000)
 })
@@ -28,7 +44,6 @@ describe('API with a custom template', () => {
     const postPrompt: TemplatePostPrompt = {
       temperature: 0.7,
       frequencyPenalty: 0.5,
       presencePenalty: 0.5,
-      logitBias: 0,
       templateFile: './tests/data/template-2.md',
       promptFolder: './prompts',
       debug: true,
@@ -38,6 +53,6 @@ describe('API with a custom template', () => {
     const postGenerator = new PostTemplateGenerator(postPrompt)
     const post = await postGenerator.generate()
    expect(post).not.toBeNull()
-    log(post)
+    console.log(post)
   }, 300000)
 })
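For context, here is a minimal sketch (not part of the patch) of how the new `buildLLM` helper could be driven from application code; the import paths and prompt values below are illustrative assumptions based on the repository layout shown in the diff:

```ts
import { buildLLM } from './src/lib/llm'
import { BasePostPrompt } from './src/types'

// Illustrative prompt: `model` must be one of the values returned by getLLMs().
const prompt: BasePostPrompt = {
  model: 'mistral-large-latest',
  temperature: 0.8,
  apiKey: process.env.MISTRAL_API_KEY
}

// Chat model used for content generation (frequency/presence penalties applied).
const llmContent = buildLLM(prompt)

// Passing forJson = true zeroes the frequency/presence penalties, which the
// generators rely on for their llm_json instance to avoid malformed JSON output.
const llmJson = buildLLM(prompt, true)
```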