Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 56 additions & 31 deletions src/LLMProviders/chatModelManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import { ChatOllama } from "@langchain/ollama";
import { ChatOpenAI } from "@langchain/openai";
import { ChatXAI } from "@langchain/xai";
import { Notice } from "obsidian";
import { GitHubCopilotProvider } from "./githubCopilotProvider";
import { CopilotChatModel } from "./githubCopilotChatModel";

type ChatConstructorType = {
new (config: any): any;
};

const CHAT_PROVIDER_CONSTRUCTORS = {
const CHAT_PROVIDER_CONSTRUCTORS: Partial<Record<ChatModelProviders, ChatConstructorType>> = {
[ChatModelProviders.OPENAI]: ChatOpenAI,
[ChatModelProviders.AZURE_OPENAI]: ChatOpenAI,
[ChatModelProviders.ANTHROPIC]: ChatAnthropic,
Expand All @@ -43,8 +45,6 @@ const CHAT_PROVIDER_CONSTRUCTORS = {
[ChatModelProviders.DEEPSEEK]: ChatDeepSeek,
} as const;

type ChatProviderConstructMap = typeof CHAT_PROVIDER_CONSTRUCTORS;

export default class ChatModelManager {
private static instance: ChatModelManager;
private static chatModel: BaseChatModel | null;
Expand All @@ -57,7 +57,7 @@ export default class ChatModelManager {
}
>;

private readonly providerApiKeyMap: Record<ChatModelProviders, () => string> = {
private readonly providerApiKeyMap: Partial<Record<ChatModelProviders, () => string>> = {
[ChatModelProviders.OPENAI]: () => getSettings().openAIApiKey,
[ChatModelProviders.GOOGLE]: () => getSettings().googleApiKey,
[ChatModelProviders.AZURE_OPENAI]: () => getSettings().azureOpenAIApiKey,
Expand Down Expand Up @@ -97,10 +97,16 @@ export default class ChatModelManager {
const isThinkingEnabled =
modelName.startsWith("claude-3-7-sonnet") || modelName.startsWith("claude-sonnet-4");

// For GitHub Copilot, streaming is not supported
const streaming =
customModel.provider === ChatModelProviders.GITHUB_COPILOT
? false
: (customModel.stream ?? true);

// Base config without temperature when thinking is enabled
const baseConfig: Omit<ModelConfig, "maxTokens" | "maxCompletionTokens" | "temperature"> = {
modelName: modelName,
streaming: customModel.stream ?? true,
streaming,
maxRetries: 3,
maxConcurrency: 3,
enableCors: customModel.enableCors,
Expand All @@ -111,9 +117,7 @@ export default class ChatModelManager {
(baseConfig as any).temperature = customModel.temperature ?? settings.temperature;
}

const providerConfig: {
[K in keyof ChatProviderConstructMap]: ConstructorParameters<ChatProviderConstructMap[K]>[0];
} = {
const providerConfig = {
[ChatModelProviders.OPENAI]: {
modelName: modelName,
openAIApiKey: await getDecryptedKey(customModel.apiKey || settings.openAIApiKey),
Expand Down Expand Up @@ -250,7 +254,7 @@ export default class ChatModelManager {
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
};
} as any;

const selectedProviderConfig =
providerConfig[customModel.provider as keyof typeof providerConfig] || {};
Expand Down Expand Up @@ -315,7 +319,7 @@ export default class ChatModelManager {
const constructor = this.getProviderConstructor(model);
const getDefaultApiKey = this.providerApiKeyMap[model.provider as ChatModelProviders];

const apiKey = model.apiKey || getDefaultApiKey();
const apiKey = model.apiKey || (getDefaultApiKey ? getDefaultApiKey() : "");
const modelKey = getModelKeyFromModel(model);
modelMap[modelKey] = {
hasApiKey: Boolean(model.apiKey || apiKey),
Expand All @@ -327,11 +331,9 @@ export default class ChatModelManager {
}

getProviderConstructor(model: CustomModel): ChatConstructorType {
const constructor: ChatConstructorType =
CHAT_PROVIDER_CONSTRUCTORS[model.provider as ChatModelProviders];
const constructor = CHAT_PROVIDER_CONSTRUCTORS[model.provider as ChatModelProviders];
if (!constructor) {
console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`);
throw new Error(`Unknown provider: ${model.provider} for model: ${model.name}`);
throw new Error(`No chat model constructor registered for provider: ${model.provider}`);
}
return constructor;
}
Expand All @@ -344,35 +346,45 @@ export default class ChatModelManager {
}

async setChatModel(model: CustomModel): Promise<void> {
const modelKey = getModelKeyFromModel(model);
try {
const modelInstance = await this.createModelInstance(model);
ChatModelManager.chatModel = modelInstance;
} catch (error) {
logError(error);
new Notice(`Error creating model: ${modelKey}`);
ChatModelManager.chatModel = await this.createModelInstance(model);
logInfo(`Chat model set to ${model.name}`);
} catch (e) {
logError("Failed to set chat model:", e);
new Notice(`Failed to set chat model: ${e.message}`);
ChatModelManager.chatModel = null;
}
}

async createModelInstance(model: CustomModel): Promise<BaseChatModel> {
// Create and return the appropriate model
// Validate model existence
if (!model) {
throw new Error("No model provided to createModelInstance.");
}

// Special handling for GitHub Copilot
if (model.provider === ChatModelProviders.GITHUB_COPILOT) {
const provider = new GitHubCopilotProvider();
return new CopilotChatModel({ provider, modelName: model.name });
}

// Validate model is enabled and has API key if required
const modelKey = getModelKeyFromModel(model);
const selectedModel = ChatModelManager.modelMap[modelKey];
const selectedModel = ChatModelManager.modelMap?.[modelKey];
if (!selectedModel) {
throw new Error(`No model found for: ${modelKey}`);
throw new Error(`Model '${model.name}' is not enabled or not found in the model map.`);
}
if (!selectedModel.hasApiKey) {
const errorMessage = `API key is not provided for the model: ${modelKey}.`;
new Notice(errorMessage);
throw new Error(errorMessage);
throw new Error(
`API key is missing for model '${model.name}'. Please check your API key settings.`
);
}

const modelConfig = await this.getModelConfig(model);
// Only now get the constructor
const AIConstructor = this.getProviderConstructor(model);
const config = await this.getModelConfig(model);

const newModelInstance = new selectedModel.AIConstructor({
...modelConfig,
});
return newModelInstance;
return new AIConstructor(config);
}

validateChatModel(chatModel: BaseChatModel): boolean {
Expand Down Expand Up @@ -427,6 +439,19 @@ export default class ChatModelManager {
}

async ping(model: CustomModel): Promise<boolean> {
if (model.provider === ChatModelProviders.GITHUB_COPILOT) {
const provider = new GitHubCopilotProvider();
const state = provider.getAuthState();
if (state.status === "authenticated") {
new Notice("GitHub Copilot is authenticated.");
return true;
} else {
new Notice(
"GitHub Copilot is not authenticated. Please set it up in the 'Basic' settings tab."
);
return false;
}
}
const tryPing = async (enableCors: boolean) => {
const modelToTest = { ...model, enableCors };
const modelConfig = await this.getModelConfig(modelToTest);
Expand Down
83 changes: 83 additions & 0 deletions src/LLMProviders/githubCopilotChatModel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import {
BaseChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, type BaseMessage, type MessageContent } from "@langchain/core/messages";
import { type ChatResult, ChatGeneration } from "@langchain/core/outputs";
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { GitHubCopilotProvider } from "./githubCopilotProvider";
import { extractTextFromChunk } from "@/utils";

/**
 * Constructor parameters for `CopilotChatModel`.
 * Extends LangChain's `BaseChatModelParams` with the Copilot-specific wiring.
 */
export interface CopilotChatModelParams extends BaseChatModelParams {
  /** GitHub Copilot provider that performs the actual chat request via `sendChatMessage`. */
  provider: GitHubCopilotProvider;
  /** Model identifier forwarded to the Copilot API on every request. */
  modelName: string;
}

/**
 * LangChain chat-model adapter for GitHub Copilot.
 *
 * Delegates each request to a `GitHubCopilotProvider` and wraps the first
 * choice of the response in a LangChain `ChatResult`. Responses are produced
 * in a single (non-streaming) round trip.
 */
export class CopilotChatModel extends BaseChatModel {
  lc_serializable = false;
  lc_namespace = ["langchain", "chat_models", "copilot"];
  private provider: GitHubCopilotProvider;
  modelName: string;

  constructor(fields: CopilotChatModelParams) {
    super(fields);
    this.provider = fields.provider;
    this.modelName = fields.modelName;
  }

  _llmType(): string {
    return "copilot-chat-model";
  }

  /**
   * Translate a LangChain message type into the role string expected by the
   * Copilot chat API. Unrecognized types (including "generic") fall back to
   * the "user" role.
   */
  private _convertMessageType(messageType: string): string {
    const roleByMessageType: Record<string, string> = {
      human: "user",
      ai: "assistant",
      system: "system",
      tool: "tool",
      function: "function",
    };
    return roleByMessageType[messageType] ?? "user";
  }

  /**
   * Send the conversation to Copilot and return the reply as a ChatResult.
   *
   * @param messages - Conversation history in LangChain message form.
   * @param options - Parsed call options (unused by this adapter).
   * @param runManager - Optional callback manager (unused; no streaming).
   */
  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    // Flatten each message to a { role, content } pair for the Copilot API.
    const requestMessages = messages.map((message) => ({
      role: this._convertMessageType(message._getType()),
      content: extractTextFromChunk(message.content),
    }));

    const response = await this.provider.sendChatMessage(requestMessages, this.modelName);
    const replyText = response.choices?.[0]?.message?.content || "";

    const generation: ChatGeneration = {
      text: replyText,
      message: new AIMessage(replyText),
    };

    return {
      generations: [generation],
      llmOutput: {}, // add more details here if needed
    };
  }

  /**
   * A simple approximation: ~4 chars per token for English text
   * This matches the fallback behavior in ChatModelManager.countTokens
   */
  async getNumTokens(content: MessageContent): Promise<number> {
    const asText = typeof content === "string" ? content : JSON.stringify(content);
    return asText ? Math.ceil(asText.length / 4) : 0;
  }
}
Loading