diff --git a/src/components/Chat/Chat.tsx b/src/components/Chat/Chat.tsx
index 494f79fb7..eba622e48 100644
--- a/src/components/Chat/Chat.tsx
+++ b/src/components/Chat/Chat.tsx
@@ -359,8 +359,10 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
     if (getCurrentPageName() != 'gpt4') {
       homeDispatch({ field: 'isRetrievalLoading', value: true })
       // Extract text from all user messages in the conversation
-      const token_limit =
-        OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit
+      // Check models from home context
+      const token_limit = selectedConversation.model.tokenLimit
+      // const token_limit =
+      //   OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit
 
       // ! DISABLE MQR FOR NOW -- too unreliable
       // const useMQRetrieval = localStorage.getItem('UseMQRetrieval') === 'true'
diff --git a/src/components/Chat/ModelSelect.tsx b/src/components/Chat/ModelSelect.tsx
index 37f4a4c8a..e18a3cd39 100644
--- a/src/components/Chat/ModelSelect.tsx
+++ b/src/components/Chat/ModelSelect.tsx
@@ -1,7 +1,6 @@
 import { IconChevronDown, IconExternalLink } from '@tabler/icons-react'
 import { useContext } from 'react'
 import { useMediaQuery } from '@mantine/hooks'
-import { type OpenAIModel } from '@/types/openai'
 import HomeContext from '~/pages/api/home/home.context'
 import { montserrat_heading, montserrat_paragraph } from 'fonts'
 import { Input, Select, Title } from '@mantine/core'
@@ -33,7 +32,7 @@ export const ModelSelect = React.forwardRef(
       selectedConversation &&
         handleUpdateConversation(selectedConversation, {
           key: 'model',
-          value: model as OpenAIModel,
+          value: model,
         })
     }
 
diff --git a/src/pages/api/home/home.state.tsx b/src/pages/api/home/home.state.tsx
index f239a3f72..a6f910902 100644
--- a/src/pages/api/home/home.state.tsx
+++ b/src/pages/api/home/home.state.tsx
@@ -4,6 +4,7 @@ import { FolderInterface } from '@/types/folder'
 import { OpenAIModel, OpenAIModelID } from '@/types/openai'
 import { PluginKey } from '@/types/plugin'
 import { Prompt } from '@/types/prompt'
+import { SupportedModels } from '~/types/LLMProvider'
 
 export interface HomeInitialState {
   apiKey: string
@@ -12,7 +13,7 @@ export interface HomeInitialState {
   lightMode: 'light' | 'dark'
   messageIsStreaming: boolean
   modelError: ErrorMessage | null
-  models: OpenAIModel[]
+  models: SupportedModels
   selectedModel: OpenAIModel | null
   folders: FolderInterface[]
   conversations: Conversation[]
diff --git a/src/pages/api/models.ts b/src/pages/api/models.ts
index 74e815c97..e95ae51dd 100644
--- a/src/pages/api/models.ts
+++ b/src/pages/api/models.ts
@@ -8,7 +8,7 @@ import {
 import { OpenAIModel, OpenAIModelID, OpenAIModels } from '@/types/openai'
 import { decrypt, isEncrypted } from '~/utils/crypto'
 
-import { LLMProvider, ProviderNames } from '~/types/LLMProviderKeys'
+import { LLMProvider, ProviderNames } from '~/types/LLMProvider'
 import { getOllamaModels } from '~/utils/modelProviders/ollama'
 
 export const config = {
@@ -117,7 +117,6 @@ const handler = async (req: Request): Promise<Response> => {
             return {
               id: model.id,
               name: OpenAIModels[value].name,
-              maxLength: OpenAIModels[value].maxLength,
               tokenLimit: OpenAIModels[value].tokenLimit,
             }
           }
@@ -126,9 +125,13 @@ const handler = async (req: Request): Promise<Response> => {
       })
       .filter((model): model is OpenAIModel => model !== undefined)
 
-    // console.log('Final list of Models: ', models)
+    const finalModels = [...models, ...ollamaModels]
 
-    return new Response(JSON.stringify(models), { status: 200 })
+    console.log('OpenAI Models: ', models)
+    console.log('Ollama Models: ', ollamaModels)
+    console.log('Final combined: ', finalModels)
+
+    return new Response(JSON.stringify(finalModels), { status: 200 })
   } catch (error) {
     console.error(error)
     return new Response('Error', { status: 500 })
diff --git a/src/types/LLMProviderKeys.ts b/src/types/LLMProvider.ts
similarity index 60%
rename from src/types/LLMProviderKeys.ts
rename to src/types/LLMProvider.ts
index 39b3b6903..239ba0c3b 100644
--- a/src/types/LLMProviderKeys.ts
+++ b/src/types/LLMProvider.ts
@@ -1,4 +1,4 @@
-import { OllamaModel } from './OllamaProvider'
+import { OllamaModel } from '~/utils/modelProviders/ollama'
 import { OpenAIModel } from './openai'
 
 export enum ProviderNames {
@@ -6,10 +6,12 @@ export enum ProviderNames {
   OpenAI = 'OpenAI',
 }
 
+export type SupportedModels = OllamaModel[] | OpenAIModel[]
+
 export interface LLMProvider {
   provider: ProviderNames
   enabled: boolean
   baseUrl: string
   apiKey?: string
-  models?: OllamaModel[] | OpenAIModel[]
+  models?: SupportedModels
 }
diff --git a/src/types/OllamaProvider.ts b/src/types/OllamaProvider.ts
deleted file mode 100644
index 444f0441f..000000000
--- a/src/types/OllamaProvider.ts
+++ /dev/null
@@ -1,4 +0,0 @@
-export interface OllamaModel {
-  name: string
-  parameterSize: string
-}
\ No newline at end of file
diff --git a/src/types/openai.ts b/src/types/openai.ts
index 59c37e2a0..e777b282d 100644
--- a/src/types/openai.ts
+++ b/src/types/openai.ts
@@ -3,7 +3,6 @@
 export interface OpenAIModel {
   id: string
   name: string
-  maxLength: number // maximum length of a message in characters... should deprecate
   tokenLimit: number
 }
 
@@ -27,25 +26,21 @@ export const OpenAIModels: Record<OpenAIModelID, OpenAIModel> = {
   [OpenAIModelID.GPT_3_5]: {
     id: OpenAIModelID.GPT_3_5,
     name: 'GPT-3.5 (16k)',
-    maxLength: 12000,
     tokenLimit: 16385,
   },
   [OpenAIModelID.GPT_4]: {
     id: OpenAIModelID.GPT_4,
     name: 'GPT-4 (8k)',
-    maxLength: 24000,
     tokenLimit: 8192,
   },
   [OpenAIModelID.GPT_4_Turbo]: {
     id: OpenAIModelID.GPT_4_Turbo,
     name: 'GPT-4 Turbo (128k)',
-    maxLength: 24000,
     tokenLimit: 128000,
   },
   [OpenAIModelID.GPT_4o]: {
     id: OpenAIModelID.GPT_4o,
     name: 'GPT-4o (128k)',
-    maxLength: 24000,
     tokenLimit: 128000,
   },
 
@@ -53,19 +48,16 @@ export const OpenAIModels: Record<OpenAIModelID, OpenAIModel> = {
   [OpenAIModelID.GPT_4_AZURE]: {
     id: OpenAIModelID.GPT_4_AZURE,
     name: 'GPT-4 Turbo (128k)',
-    maxLength: 24000,
     tokenLimit: 128000,
   },
   [OpenAIModelID.GPT_4_HACKATHON]: {
     id: OpenAIModelID.GPT_4_HACKATHON,
     name: 'GPT-4 Hackathon',
-    maxLength: 24000,
     tokenLimit: 128000,
   },
   [OpenAIModelID.GPT_4_AZURE_04_09]: {
     id: OpenAIModelID.GPT_4_AZURE_04_09,
     name: 'GPT-4 Turbo 0409 (128k)',
-    maxLength: 24000,
     tokenLimit: 128000,
   },
 }
diff --git a/src/utils/modelProviders/ollama.ts b/src/utils/modelProviders/ollama.ts
index 23103908a..6e8e96572 100644
--- a/src/utils/modelProviders/ollama.ts
+++ b/src/utils/modelProviders/ollama.ts
@@ -1,20 +1,23 @@
 export interface OllamaModel {
+  id: string
   name: string
   parameterSize: string
+  tokenLimit: number
 }
 
 export const getOllamaModels = async () => {
-  console.log('In ollama GET endpoint')
   const response = await fetch('https://ollama.ncsa.ai/api/ps')
 
   if (!response.ok) {
     throw new Error(`HTTP error! status: ${response.status}`)
   }
-
   const data = await response.json()
+
   const ollamaModels: OllamaModel[] = data.models.map((model: any) => {
     return {
+      id: model.name,
       name: model.name,
-      parameterSize: model.parameter_size,
+      parameterSize: model.details.parameter_size,
+      tokenLimit: 4096,
     } as OllamaModel
   })
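
Note: the last hunk is cut off inside getOllamaModels, right after the .map() callback. As a minimal sketch of how the updated helper plausibly reads end to end: the interface, fetch, and mapping come straight from the hunk above; the explicit return type and the trailing `return ollamaModels` are assumptions, not part of the patch shown.

  export interface OllamaModel {
    id: string
    name: string
    parameterSize: string
    tokenLimit: number
  }

  export const getOllamaModels = async (): Promise<OllamaModel[]> => {
    const response = await fetch('https://ollama.ncsa.ai/api/ps')
    if (!response.ok) {
      throw new Error(`HTTP error! status: ${response.status}`)
    }
    const data = await response.json()

    // Map each running Ollama model from /api/ps into the shared OllamaModel shape.
    const ollamaModels: OllamaModel[] = data.models.map((model: any) => {
      return {
        id: model.name,
        name: model.name,
        parameterSize: model.details.parameter_size,
        tokenLimit: 4096, // hard-coded default from the diff; /api/ps does not report a context window
      } as OllamaModel
    })

    return ollamaModels // assumed: the truncated hunk ends by returning the mapped list
  }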