Skip to content

Commit

Permalink
Only ask questions after claude request completes
Browse files Browse the repository at this point in the history
This prevents issue where claude request takes longer than auth and breaks
on first ask.

Signed-off-by: Tomas Slusny <[email protected]>
  • Loading branch information
deathbeam committed Oct 30, 2024
1 parent 8260483 commit 62a02a2
Showing 1 changed file with 35 additions and 33 deletions.
68 changes: 35 additions & 33 deletions lua/CopilotChat/copilot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,8 @@ end

--- Check if the model can stream
--- @param model_name string: The model name to check
local function is_o1(model_name)
if vim.startswith(model_name, 'o1') then
return true
end
return false
local function can_stream(model_name)
return not vim.startswith(model_name, 'o1')
end

local function generate_ask_request(
Expand All @@ -203,7 +200,7 @@ local function generate_ask_request(
local messages = {}

local system_role = 'system'
if is_o1(model) then
if not can_stream(model) then
system_role = 'user'
end

Expand Down Expand Up @@ -237,13 +234,7 @@ local function generate_ask_request(
role = 'user',
})

if is_o1(model) then
return {
messages = messages,
stream = false,
model = model,
}
else
if can_stream(model) then
return {
intent = true,
model = model,
Expand All @@ -253,6 +244,12 @@ local function generate_ask_request(
top_p = 1,
messages = messages,
}
else
return {
messages = messages,
stream = false,
model = model,
}
end
end

Expand Down Expand Up @@ -365,11 +362,13 @@ function Copilot:with_auth(on_done, on_error)
end
end

function Copilot:enable_claude()
if claude_enabled then
return
end
function Copilot:with_claude(on_done, on_error)
self:with_auth(function()
if claude_enabled then
on_done()
return
end

local url = 'https://api.githubcopilot.com/models/claude-3.5-sonnet/policy'
local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
curl.post(url, {
Expand All @@ -380,16 +379,24 @@ function Copilot:enable_claude()
on_error = function(err)
err = 'Failed to enable Claude: ' .. vim.inspect(err)
log.error(err)
if on_error then
on_error(err)
end
end,
body = temp_file('{"state": "enabled"}'),
callback = function(response)
if response.status ~= 200 then
local msg = 'Failed to enable Claude: ' .. tostring(response.status)
log.error(msg)
if on_error then
on_error(msg)
end
return
end

claude_enabled = true
log.info('Claude enabled')
on_done()
end,
})
end)
Expand Down Expand Up @@ -438,6 +445,7 @@ function Copilot:ask(prompt, opts)
current_count = current_count + tiktoken.count(system_prompt)
current_count = current_count + tiktoken.count(selection_message)

-- Limit the number of files to send
if #embeddings_message.files > 0 then
local filtered_files = {}
current_count = current_count + tiktoken.count(embeddings_message.header)
Expand All @@ -451,10 +459,6 @@ function Copilot:ask(prompt, opts)
embeddings_message.files = filtered_files
end

if vim.startswith(model, 'claude') then
self:enable_claude()
end

local url = 'https://api.githubcopilot.com/chat/completions'
local body = vim.json.encode(
generate_ask_request(
Expand All @@ -476,8 +480,8 @@ function Copilot:ask(prompt, opts)

local errored = false
local full_response = ''
---@type fun(err: string, line: string)?
local stream_func = function(err, line)

local function stream_func(err, line)
if not line or errored then
return
end
Expand Down Expand Up @@ -539,12 +543,8 @@ function Copilot:ask(prompt, opts)
-- Collect full response incrementally so we can insert it to history later
full_response = full_response .. content
end
if is_o1(model) then
stream_func = nil
end

---@type fun(response: table)?
local nonstream_callback = function(response)
local function callback_func(response)
if response.status ~= 200 then
local err = 'Failed to get response: ' .. tostring(response.status)
log.error(err)
Expand Down Expand Up @@ -585,11 +585,13 @@ function Copilot:ask(prompt, opts)
})
end

if not is_o1(model) then
nonstream_callback = nil
local is_stream = can_stream(model)
local with_auth = self.with_auth
if vim.startswith(model, 'claude') then
with_auth = self.with_claude
end

self:with_auth(function()
with_auth(self, function()
local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
self.current_job = curl
.post(url, {
Expand All @@ -598,15 +600,15 @@ function Copilot:ask(prompt, opts)
body = temp_file(body),
proxy = self.proxy,
insecure = self.allow_insecure,
callback = nonstream_callback,
callback = (not is_stream) and callback_func or nil,
stream = is_stream and stream_func or nil,
on_error = function(err)
err = 'Failed to get response: ' .. vim.inspect(err)
log.error(err)
if self.current_job and on_error then
on_error(err)
end
end,
stream = stream_func,
})
:after(function()
self.current_job = nil
Expand Down

0 comments on commit 62a02a2

Please sign in to comment.