Skip to content

Commit 452be78

Browse files
committed
fix: github copilot mapping error, invalid cast that causes runtime errors, token calculation (rough approximation)
1 parent f1ee68c commit 452be78

File tree

2 files changed

+73
-21
lines changed

2 files changed

+73
-21
lines changed

src/LLMProviders/chatModelManager.ts

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@ import { err2String, isOSeriesModel, safeFetch, withSuppressedTokenWarnings } fr
1212
import { HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
1313
import { ChatAnthropic } from "@langchain/anthropic";
1414
import { ChatCohere } from "@langchain/cohere";
15-
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
16-
import { AIMessage } from "@langchain/core/messages";
17-
import { Runnable } from "@langchain/core/runnables";
15+
import {
16+
BaseChatModel,
17+
type BaseChatModelParams,
18+
} from "@langchain/core/language_models/chat_models";
19+
import { AIMessage, type BaseMessage, type MessageContent } from "@langchain/core/messages";
20+
import { type ChatResult, ChatGeneration } from "@langchain/core/outputs";
21+
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
1822
import { ChatDeepSeek } from "@langchain/deepseek";
1923
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
2024
import { ChatGroq } from "@langchain/groq";
@@ -24,36 +28,85 @@ import { ChatOpenAI } from "@langchain/openai";
2428
import { ChatXAI } from "@langchain/xai";
2529
import { Notice } from "obsidian";
2630
import { GitHubCopilotProvider } from "./githubCopilotProvider";
27-
import { ChatPromptValue } from "@langchain/core/prompt_values";
2831

29-
class CopilotRunnable extends Runnable {
32+
export interface CopilotChatModelParams extends BaseChatModelParams {
33+
provider: GitHubCopilotProvider;
34+
modelName: string;
35+
}
36+
37+
class CopilotChatModel extends BaseChatModel {
3038
lc_serializable = false;
3139
lc_namespace = ["langchain", "chat_models", "copilot"];
3240
private provider: GitHubCopilotProvider;
33-
private modelName: string;
41+
modelName: string;
42+
43+
constructor(fields: CopilotChatModelParams) {
44+
super(fields);
45+
this.provider = fields.provider;
46+
this.modelName = fields.modelName;
47+
}
3448

35-
constructor(provider: GitHubCopilotProvider, modelName: string) {
36-
super();
37-
this.provider = provider;
38-
this.modelName = modelName;
49+
_llmType(): string {
50+
return "copilot-chat-model";
3951
}
4052

41-
async invoke(input: ChatPromptValue, options?: any): Promise<any> {
42-
const messages = input.toChatMessages().map((m) => ({
43-
role: m._getType() === "human" ? "user" : "assistant",
53+
private _convertMessageType(messageType: string): string {
54+
switch (messageType) {
55+
case "human":
56+
return "user";
57+
case "ai":
58+
return "assistant";
59+
case "system":
60+
return "system";
61+
case "tool":
62+
return "tool";
63+
case "function":
64+
return "function";
65+
case "generic":
66+
default:
67+
return "user";
68+
}
69+
}
70+
71+
async _generate(
72+
messages: BaseMessage[],
73+
options: this["ParsedCallOptions"],
74+
runManager?: CallbackManagerForLLMRun
75+
): Promise<ChatResult> {
76+
const chatMessages = messages.map((m) => ({
77+
role: this._convertMessageType(m._getType()),
4478
content: m.content as string,
4579
}));
46-
const response = await this.provider.sendChatMessage(messages, this.modelName);
80+
81+
const response = await this.provider.sendChatMessage(chatMessages, this.modelName);
4782
const content = response.choices?.[0]?.message?.content || "";
48-
return new AIMessage(content);
83+
84+
const generation: ChatGeneration = {
85+
text: content,
86+
message: new AIMessage(content),
87+
};
88+
89+
return {
90+
generations: [generation],
91+
llmOutput: {}, // add more details here if needed
92+
};
93+
}
94+
95+
/**
96+
* A simple approximation: ~4 chars per token for English text
97+
* This matches the fallback behavior in ChatModelManager.countTokens
98+
*/
99+
async getNumTokens(content: MessageContent): Promise<number> {
100+
const text = typeof content === "string" ? content : JSON.stringify(content);
101+
if (!text) return 0;
102+
return Math.ceil(text.length / 4);
49103
}
50104
}
51105

52106
type ChatConstructorType = {
53107
new (config: any): any;
54108
};
55109

56-
// Placeholder for GitHub Copilot chat provider
57110
class ChatGitHubCopilot {
58111
private provider: GitHubCopilotProvider;
59112
constructor(config: any) {
@@ -420,9 +473,7 @@ export default class ChatModelManager {
420473
async createModelInstance(model: CustomModel): Promise<BaseChatModel> {
421474
if (model.provider === ChatModelProviders.GITHUB_COPILOT) {
422475
const provider = new GitHubCopilotProvider();
423-
const copilotRunnable = new CopilotRunnable(provider, model.name);
424-
// The type assertion is a bit of a hack, but it makes it work with the existing structure
425-
return copilotRunnable as unknown as BaseChatModel;
476+
return new CopilotChatModel({ provider, modelName: model.name });
426477
}
427478

428479
const AIConstructor = this.getProviderConstructor(model);

src/LLMProviders/githubCopilotProvider.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,9 @@ export class GitHubCopilotProvider {
120120
if (res.status !== 200) throw new Error("Failed to get Copilot token");
121121
const data = res.json;
122122
this.authState.copilotToken = data.token;
123-
this.authState.copilotTokenExpiresAt =
124-
Date.now() + (data.expires_at ? data.expires_at * 1000 : 3600 * 1000);
123+
this.authState.copilotTokenExpiresAt = data.expires_at
124+
? data.expires_at * 1000
125+
: Date.now() + 3600 * 1000;
125126
this.authState.status = "authenticated";
126127
// Persist Copilot token and expiration
127128
updateSetting("copilotToken", data.token);

0 commit comments

Comments
 (0)