diff --git a/lua/CopilotChat/copilot.lua b/lua/CopilotChat/copilot.lua
index bc0682d..6ed40a5 100644
--- a/lua/CopilotChat/copilot.lua
+++ b/lua/CopilotChat/copilot.lua
@@ -32,17 +32,16 @@
 ---@field save fun(self: CopilotChat.Copilot, name: string, path: string):nil
 ---@field load fun(self: CopilotChat.Copilot, name: string, path: string):table
 ---@field running fun(self: CopilotChat.Copilot):boolean
----@field select_model fun(self: CopilotChat.Copilot, callback: fun(table):nil):nil
+---@field list_models fun(self: CopilotChat.Copilot, callback: fun(table):nil):nil
 
 local log = require('plenary.log')
 local curl = require('plenary.curl')
+local prompts = require('CopilotChat.prompts')
+local tiktoken = require('CopilotChat.tiktoken')
 local utils = require('CopilotChat.utils')
 local class = utils.class
 local join = utils.join
 local temp_file = utils.temp_file
-local prompts = require('CopilotChat.prompts')
-local tiktoken = require('CopilotChat.tiktoken')
 
-local max_tokens = 8192
 local timeout = 30000
 local version_headers = {
   ['editor-version'] = 'Neovim/'
@@ -54,7 +53,6 @@
   ['editor-plugin-version'] = 'CopilotChat.nvim/2.0.0',
   ['user-agent'] = 'CopilotChat.nvim/2.0.0',
 }
-local claude_enabled = false
 
 local function uuid()
   local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
@@ -302,7 +300,8 @@ local Copilot = class(function(self, proxy, allow_insecure)
   self.sessionid = nil
   self.machineid = machine_id()
   self.current_job = nil
-  self.models_cache = nil
+  self.models = nil
+  self.claude_enabled = false
 end)
 
 function Copilot:with_auth(on_done, on_error)
@@ -362,63 +361,107 @@ function Copilot:with_auth(on_done, on_error)
   end
 end
 
-function Copilot:with_claude(on_done, on_error)
-  self:with_auth(function()
-    if claude_enabled then
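+--- Fetch the list of available models once and cache it on the instance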
+function Copilot:with_models(on_done, on_error)
+  if self.models ~= nil then
+    on_done()
+    return
+  end
+
+  local url = 'https://api.githubcopilot.com/models'
+  local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
+  curl.get(url, {
+    timeout = timeout,
+    headers = headers,
+    proxy = self.proxy,
+    insecure = self.allow_insecure,
+    on_error = function(err)
+      err = 'Failed to get response: ' .. vim.inspect(err)
+      log.error(err)
+      if on_error then
+        on_error(err)
+      end
+    end,
+    callback = function(response)
+      if response.status ~= 200 then
+        local msg = 'Failed to fetch models: ' .. tostring(response.status)
+        log.error(msg)
+        if on_error then
+          on_error(msg)
+        end
+        return
+      end
+
+      -- Find chat models
+      local models = vim.json.decode(response.body)['data']
+      local out = {}
+      for _, model in ipairs(models) do
+        if model['capabilities']['type'] == 'chat' then
+          out[model['id']] = model
+        end
+      end
+
+      log.info('Models fetched')
+      self.models = out
       on_done()
-      return
-    end
+    end,
+  })
+end
 
-    local business_check = 'cannot enable policy inline for business users'
-    local business_msg =
-      'Claude is probably enabled (for business users needs to be enabled manually).'
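+--- Enable the Claude policy the first time a claude-* model is requested; no-op otherwise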
+function Copilot:with_claude(model, on_done, on_error)
+  if self.claude_enabled or not vim.startswith(model, 'claude') then
+    on_done()
+    return
+  end
 
-    local url = 'https://api.githubcopilot.com/models/claude-3.5-sonnet/policy'
-    local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
-    curl.post(url, {
-      timeout = timeout,
-      headers = headers,
-      proxy = self.proxy,
-      insecure = self.allow_insecure,
-      on_error = function(err)
-        err = 'Failed to enable Claude: ' .. vim.inspect(err)
-        if string.find(err, business_check) then
-          claude_enabled = true
+  local business_check = 'cannot enable policy inline for business users'
+  local business_msg =
+    'Claude is probably enabled (for business users needs to be enabled manually).'
+
+  local url = 'https://api.githubcopilot.com/models/claude-3.5-sonnet/policy'
+  local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
+  curl.post(url, {
+    timeout = timeout,
+    headers = headers,
+    proxy = self.proxy,
+    insecure = self.allow_insecure,
+    on_error = function(err)
+      err = 'Failed to enable Claude: ' .. vim.inspect(err)
+      if string.find(err, business_check) then
+        self.claude_enabled = true
+        log.info(business_msg)
+        on_done()
+        return
+      end
+
+      log.error(err)
+      if on_error then
+        on_error(err)
+      end
+    end,
+    body = temp_file('{"state": "enabled"}'),
+    callback = function(response)
+      if response.status ~= 200 then
+        if string.find(tostring(response.body), business_check) then
+          self.claude_enabled = true
           log.info(business_msg)
           on_done()
           return
         end
-        log.error(err)
-        if on_error then
-          on_error(err)
-        end
-      end,
-      body = temp_file('{"state": "enabled"}'),
-      callback = function(response)
-        if response.status ~= 200 then
-          if string.find(tostring(response.body), business_check) then
-            claude_enabled = true
-            log.info(business_msg)
-            on_done()
-            return
-          end
-
-          local msg = 'Failed to enable Claude: ' .. tostring(response.status)
+        local msg = 'Failed to enable Claude: ' .. tostring(response.status)
 
-          log.error(msg)
-          if on_error then
-            on_error(msg)
-          end
-          return
+        log.error(msg)
+        if on_error then
+          on_error(msg)
         end
+        return
+      end
 
-        claude_enabled = true
-        log.info('Claude enabled')
-        on_done()
-      end,
-    })
-  end)
+      self.claude_enabled = true
+      log.info('Claude enabled')
+      on_done()
+    end,
+  })
 end
 
 --- Ask a question to Copilot
@@ -453,235 +496,176 @@ function Copilot:ask(prompt, opts)
     self:stop()
   end
 
-  local selection_message =
-    generate_selection_message(filename, filetype, start_row, end_row, selection)
-  local embeddings_message = generate_embeddings_message(embeddings)
-
-  -- Count tokens
-  self.token_count = self.token_count + tiktoken.count(prompt)
-
-  local current_count = 0
-  current_count = current_count + tiktoken.count(system_prompt)
-  current_count = current_count + tiktoken.count(selection_message)
-
-  -- Limit the number of files to send
-  if #embeddings_message.files > 0 then
-    local filtered_files = {}
-    current_count = current_count + tiktoken.count(embeddings_message.header)
-    for _, file in ipairs(embeddings_message.files) do
-      local file_count = current_count + tiktoken.count(file)
-      if file_count + self.token_count < max_tokens then
-        current_count = file_count
-        table.insert(filtered_files, file)
-      end
-    end
-    embeddings_message.files = filtered_files
-  end
-
-  local url = 'https://api.githubcopilot.com/chat/completions'
-  local body = vim.json.encode(
-    generate_ask_request(
-      self.history,
-      prompt,
-      embeddings_message,
-      selection_message,
-      system_prompt,
-      model,
-      temperature
-    )
-  )
-
-  -- Add the prompt to history after we have encoded the request
-  table.insert(self.history, {
-    content = prompt,
-    role = 'user',
-  })
-
-  local errored = false
-  local full_response = ''
-
-  local function stream_func(err, line)
-    if not line or errored then
-      return
-    end
+  self:with_auth(function()
+    self:with_models(function()
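+      -- Use the model's advertised token limit and tokenizer, falling back to
+      -- conservative defaults when the model is missing from the fetched list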
+      local capabilities = self.models[model] and self.models[model].capabilities
+        or { limits = { max_prompt_tokens = 8192 }, tokenizer = 'cl100k_base' }
+      local max_tokens = capabilities.limits.max_prompt_tokens -- FIXME: Is max_prompt_tokens the right limit?
+      local tokenizer = capabilities.tokenizer
+      log.debug('Max tokens: ' .. max_tokens)
+      log.debug('Tokenizer: ' .. tokenizer)
+
+      local selection_message =
+        generate_selection_message(filename, filetype, start_row, end_row, selection)
+      local embeddings_message = generate_embeddings_message(embeddings)
+
+      tiktoken.load(tokenizer, function()
+        -- Count tokens
+        self.token_count = self.token_count + tiktoken.count(prompt)
+        local current_count = 0
+        current_count = current_count + tiktoken.count(system_prompt)
+        current_count = current_count + tiktoken.count(selection_message)
+
+        -- Limit the number of files to send
+        if #embeddings_message.files > 0 then
+          local filtered_files = {}
+          current_count = current_count + tiktoken.count(embeddings_message.header)
+          for _, file in ipairs(embeddings_message.files) do
+            local file_count = current_count + tiktoken.count(file)
+            if file_count + self.token_count < max_tokens then
+              current_count = file_count
+              table.insert(filtered_files, file)
+            end
+          end
+          embeddings_message.files = filtered_files
+        end
 
-    if err or vim.startswith(line, '{"error"') then
-      err = 'Failed to get response: ' .. (err and vim.inspect(err) or line)
-      errored = true
-      log.error(err)
-      if self.current_job and on_error then
-        on_error(err)
-      end
-      return
-    end
+        local url = 'https://api.githubcopilot.com/chat/completions'
+        local body = vim.json.encode(
+          generate_ask_request(
+            self.history,
+            prompt,
+            embeddings_message,
+            selection_message,
+            system_prompt,
+            model,
+            temperature
+          )
+        )
 
-    line = line:gsub('data: ', '')
-    if line == '' then
-      return
-    elseif line == '[DONE]' then
-      log.trace('Full response: ' .. full_response)
-      self.token_count = self.token_count + tiktoken.count(full_response)
+        -- Add the prompt to history after we have encoded the request
+        table.insert(self.history, {
+          content = prompt,
+          role = 'user',
+        })
 
-      if self.current_job and on_done then
-        on_done(full_response, self.token_count + current_count)
-      end
+        local errored = false
+        local full_response = ''
 
-      table.insert(self.history, {
-        content = full_response,
-        role = 'assistant',
-      })
-      return
-    end
+        local function stream_func(err, line)
+          if not line or errored then
+            return
+          end
 
-    local ok, content = pcall(vim.json.decode, line, {
-      luanil = {
-        object = true,
-        array = true,
-      },
-    })
+          if err or vim.startswith(line, '{"error"') then
+            err = 'Failed to get response: ' .. (err and vim.inspect(err) or line)
+            errored = true
+            log.error(err)
+            if self.current_job and on_error then
+              on_error(err)
+            end
+            return
+          end
 
-    if not ok then
-      err = 'Failed to parse response: ' .. vim.inspect(content) .. '\n' .. line
-      log.error(err)
-      return
-    end
+          line = line:gsub('data: ', '')
+          if line == '' then
+            return
+          elseif line == '[DONE]' then
+            log.trace('Full response: ' .. full_response)
+            self.token_count = self.token_count + tiktoken.count(full_response)
+
+            if self.current_job and on_done then
+              on_done(full_response, self.token_count + current_count)
+            end
+
+            table.insert(self.history, {
+              content = full_response,
+              role = 'assistant',
+            })
+            return
+          end
 
-    if not content.choices or #content.choices == 0 then
-      return
-    end
+          local ok, content = pcall(vim.json.decode, line, {
+            luanil = {
+              object = true,
+              array = true,
+            },
+          })
 
-    content = content.choices[1].delta.content
-    if not content then
-      return
-    end
+          if not ok then
+            err = 'Failed to parse response: ' .. vim.inspect(content) .. '\n' .. line
+            log.error(err)
+            return
+          end
 
-    if self.current_job and on_progress then
-      on_progress(content)
-    end
+          if not content.choices or #content.choices == 0 then
+            return
+          end
 
-    -- Collect full response incrementally so we can insert it to history later
-    full_response = full_response .. content
-  end
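+          -- Some models return the whole message at once (choice.message)
+          -- instead of streaming deltas (choice.delta); handle both shapes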
+          local choice = content.choices[1]
+          local is_full = choice.message ~= nil
+          content = is_full and choice.message.content or choice.delta.content
 
-  local function callback_func(response)
-    if response.status ~= 200 then
-      local err = 'Failed to get response: ' .. tostring(response.status)
-      log.error(err)
-      if on_error then
-        on_error(err)
-      end
-      return
-    end
+          if not content then
+            return
+          end
 
-    local ok, content = pcall(vim.json.decode, response.body, {
-      luanil = {
-        object = true,
-        array = true,
-      },
-    })
+          if self.current_job and on_progress then
+            on_progress(content)
+          end
 
-    if not ok then
-      local err = 'Failed to parse response: ' .. vim.inspect(content) .. '\n' .. response.body
-      log.error(err)
-      if on_error then
-        on_error(err)
-      end
-      return
-    end
+          if is_full then
+            log.trace('Full response: ' .. content)
+            self.token_count = self.token_count + tiktoken.count(content)
 
-    full_response = content.choices[1].message.content
-    if on_progress then
-      on_progress(full_response)
-    end
-    self.token_count = self.token_count + tiktoken.count(full_response)
-    if on_done then
-      on_done(full_response, self.token_count + current_count)
-    end
+            if self.current_job and on_done then
+              on_done(content, self.token_count + current_count)
+            end
 
-    table.insert(self.history, {
-      content = full_response,
-      role = 'assistant',
-    })
-  end
+            table.insert(self.history, {
+              content = content,
+              role = 'assistant',
+            })
+            return
+          end
 
-  local is_stream = can_stream(model)
-  local with_auth = self.with_auth
-  if vim.startswith(model, 'claude') then
-    with_auth = self.with_claude
-  end
+          -- Collect full response incrementally so we can insert it to history later
+          full_response = full_response .. content
+        end
 
-  with_auth(self, function()
-    local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
-    self.current_job = curl
-      .post(url, {
-        timeout = timeout,
-        headers = headers,
-        body = temp_file(body),
-        proxy = self.proxy,
-        insecure = self.allow_insecure,
-        callback = (not is_stream) and callback_func or nil,
-        stream = is_stream and stream_func or nil,
-        on_error = function(err)
-          err = 'Failed to get response: ' .. vim.inspect(err)
-          log.error(err)
-          if self.current_job and on_error then
-            on_error(err)
-          end
-        end,
-      })
-      :after(function()
-        self.current_job = nil
+        self:with_claude(model, function()
+          self.current_job = curl
+            .post(url, {
+              timeout = timeout,
+              headers = generate_headers(self.token.token, self.sessionid, self.machineid),
+              body = temp_file(body),
+              proxy = self.proxy,
+              insecure = self.allow_insecure,
+              stream = stream_func,
+              on_error = function(err)
+                err = 'Failed to get response: ' .. vim.inspect(err)
+                log.error(err)
+                if self.current_job and on_error then
+                  on_error(err)
+                end
+              end,
+            })
+            :after(function()
+              self.current_job = nil
+            end)
+        end, on_error)
       end)
-  end)
+    end, on_error)
+  end, on_error)
 end
 
---- Fetch & allow model selection
+--- List available models
 ---@param callback fun(table):nil
-function Copilot:select_model(callback)
-  if self.models_cache ~= nil then
-    callback(self.models_cache)
-    return
-  end
-
-  local url = 'https://api.githubcopilot.com/models'
+function Copilot:list_models(callback)
   self:with_auth(function()
-    local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
-    curl.get(url, {
-      timeout = timeout,
-      headers = headers,
-      proxy = self.proxy,
-      insecure = self.allow_insecure,
-      on_error = function(err)
-        err = 'Failed to get response: ' .. vim.inspect(err)
-        log.error(err)
-      end,
-      callback = function(response)
-        if response.status ~= 200 then
-          local msg = 'Failed to fetch models: ' .. tostring(response.status)
-          log.error(msg)
-          return
-        end
-
-        local models = vim.json.decode(response.body)['data']
-        local selections = {}
-        for _, model in ipairs(models) do
-          if model['capabilities']['type'] == 'chat' then
-            table.insert(selections, model['version'])
-          end
-        end
-        -- Remove duplicates from selection
-        local hash = {}
-        selections = vim.tbl_filter(function(model)
-          if not hash[model] then
-            hash[model] = true
-            return true
-          end
-          return false
-        end, selections)
-        self.models_cache = selections
-        callback(self.models_cache)
-      end,
-    })
+    self:with_models(function()
+      callback(vim.tbl_keys(self.models))
+    end)
   end)
 end
diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua
index c5d6c2a..309fcdd 100644
--- a/lua/CopilotChat/init.lua
+++ b/lua/CopilotChat/init.lua
@@ -6,7 +6,6 @@ local Overlay = require('CopilotChat.overlay')
 local context = require('CopilotChat.context')
 local prompts = require('CopilotChat.prompts')
 local debuginfo = require('CopilotChat.debuginfo')
-local tiktoken = require('CopilotChat.tiktoken')
 local utils = require('CopilotChat.utils')
 
 local M = {}
@@ -351,7 +350,7 @@ end
 
 --- Select a Copilot GPT model.
 function M.select_model()
-  state.copilot:select_model(function(models)
+  state.copilot:list_models(function(models)
     vim.schedule(function()
       vim.ui.select(models, {
         prompt = 'Select a model',
@@ -445,7 +444,7 @@ function M.ask(prompt, config, source)
     vim.schedule(function()
      append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
      state.response = response
-      if tiktoken.available() and token_count and token_count > 0 then
+      if token_count and token_count > 0 then
        state.chat:finish(token_count .. ' tokens used')
      else
        state.chat:finish()
@@ -606,9 +605,9 @@ function M.setup(config)
   if state.copilot then
     state.copilot:stop()
   else
-    tiktoken.setup((config and config.model) or nil)
     debuginfo.setup()
   end
+
   state.copilot = Copilot(M.config.proxy, M.config.allow_insecure)
   M.debug(M.config.debug)
diff --git a/lua/CopilotChat/tiktoken.lua b/lua/CopilotChat/tiktoken.lua
index 94c76be..38dc54d 100644
--- a/lua/CopilotChat/tiktoken.lua
+++ b/lua/CopilotChat/tiktoken.lua
@@ -1,11 +1,10 @@
 local curl = require('plenary.curl')
+local log = require('plenary.log')
 local tiktoken_core = nil
+local current_tokenizer = nil
 
----Get the path of the cache directory
----@param fname string
----@return string
 local function get_cache_path(fname)
-  vim.fn.mkdir(vim.fn.stdpath('cache'), 'p')
+  vim.fn.mkdir(tostring(vim.fn.stdpath('cache')), 'p')
   return vim.fn.stdpath('cache') .. '/' .. fname
 end
@@ -20,13 +19,11 @@ local function file_exists(name)
 end
 
 --- Load tiktoken data from cache or download it
-local function load_tiktoken_data(done, model)
-  local tiktoken_url = 'https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken'
-  -- If model is gpt-4o, use o200k_base.tiktoken
-  if model ~= nil and vim.startswith(model, 'gpt-4o') then
-    tiktoken_url = 'https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken'
-  end
-  -- Take filename after the last slash of the url
+local function load_tiktoken_data(done, tokenizer)
+  local tiktoken_url = 'https://openaipublic.blob.core.windows.net/encodings/'
+    .. tokenizer
+    .. '.tiktoken'
+  log.info('Downloading tiktoken data from ' .. tiktoken_url)
   local cache_path = get_cache_path(tiktoken_url:match('.+/(.+)'))
 
   local async
@@ -48,25 +45,34 @@ end
 
 local M = {}
 
----@param model string|nil
-function M.setup(model)
+function M.load(tokenizer, on_done)
+  if tokenizer == current_tokenizer then
+    on_done()
+    return
+  end
+
   local ok, core = pcall(require, 'tiktoken_core')
   if not ok then
+    on_done()
     return
   end
 
-  load_tiktoken_data(function(path)
-    local special_tokens = {}
-    special_tokens['<|endoftext|>'] = 100257
-    special_tokens['<|fim_prefix|>'] = 100258
-    special_tokens['<|fim_middle|>'] = 100259
-    special_tokens['<|fim_suffix|>'] = 100260
-    special_tokens['<|endofprompt|>'] = 100276
-    local pat_str =
-      "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-    core.new(path, special_tokens, pat_str)
-    tiktoken_core = core
-  end, model)
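+  -- This can run inside curl's callback (a fast event context), so defer to
+  -- the main loop before load_tiktoken_data touches vim.fn APIs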
+  vim.schedule(function()
+    load_tiktoken_data(function(path)
+      local special_tokens = {}
+      special_tokens['<|endoftext|>'] = 100257
+      special_tokens['<|fim_prefix|>'] = 100258
+      special_tokens['<|fim_middle|>'] = 100259
+      special_tokens['<|fim_suffix|>'] = 100260
+      special_tokens['<|endofprompt|>'] = 100276
+      local pat_str =
+        "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+      core.new(path, special_tokens, pat_str)
+      tiktoken_core = core
+      current_tokenizer = tokenizer
+      on_done()
+    end, tokenizer)
+  end)
 end
 
 function M.available()