Skip to content

Commit

Permalink
Support A/B Compiler Arguments Traits to Completions Prompt
Browse files Browse the repository at this point in the history
- Depends on cpptools' update to provide ProjectContextResult.
- Added two new traits
  - compilerArguments: a list of compiler command arguments that could affect Copilot generating completions.
  - compilerUserDefines: a list of compiler command defines that could affect Copilot generating completions. Macro references are used to exclude the ones that are not relavent.
- A/B Experimental flags
  - copilotcppTraits: boolean flag to enable cpp traits
  - copilotcppExcludeTraits: string array to exclude individual trait, i.e., compilerArguments.
  - copilotcppMsvcCompilerArgumentFilter: regex string to match compiler arguments for GCC.
  - copilotcppClangCompilerArgumentFilter: regex string to match compiler arguments for Clang.
  - copilotcppGccCompilerArgumentFilter: regex string to match compiler arguments for MSVC.
  • Loading branch information
kuchungmsft committed Nov 12, 2024
1 parent c9cae0b commit cf72c95
Show file tree
Hide file tree
Showing 7 changed files with 494 additions and 88 deletions.
25 changes: 25 additions & 0 deletions Extension/src/LanguageServer/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,22 @@ export interface ChatContextResult {
targetArchitecture: string;
}

export interface FileContextResult {
compilerArgs: string[];
compilerUserDefinesRelavant: string[];
compilerOriginalUserDefineCount: number;
macroReferenceCount: number;
}

export interface ProjectContextResult {
language: string;
standardVersion: string;
compiler: string;
targetPlatform: string;
targetArchitecture: string;
fileContext: FileContextResult;
}

// Requests
const PreInitializationRequest: RequestType<void, string, void> = new RequestType<void, string, void>('cpptools/preinitialize');
const InitializationRequest: RequestType<CppInitializationParams, void, void> = new RequestType<CppInitializationParams, void, void>('cpptools/initialize');
Expand All @@ -561,6 +577,7 @@ const GenerateDoxygenCommentRequest: RequestType<GenerateDoxygenCommentParams, G
const ChangeCppPropertiesRequest: RequestType<CppPropertiesParams, void, void> = new RequestType<CppPropertiesParams, void, void>('cpptools/didChangeCppProperties');
const IncludesRequest: RequestType<GetIncludesParams, GetIncludesResult, void> = new RequestType<GetIncludesParams, GetIncludesResult, void>('cpptools/getIncludes');
const CppContextRequest: RequestType<void, ChatContextResult, void> = new RequestType<void, ChatContextResult, void>('cpptools/getChatContext');
const ProjectContextRequest: RequestType<void, ProjectContextResult, void> = new RequestType<void, ProjectContextResult, void>('cpptools/getProjectContext');

// Notifications to the server
const DidOpenNotification: NotificationType<DidOpenTextDocumentParams> = new NotificationType<DidOpenTextDocumentParams>('textDocument/didOpen');
Expand Down Expand Up @@ -792,6 +809,7 @@ export interface Client {
addTrustedCompiler(path: string): Promise<void>;
getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult>;
getChatContext(token: vscode.CancellationToken): Promise<ChatContextResult>;
getProjectContext(token: vscode.CancellationToken): Promise<ProjectContextResult>;
}

export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client {
Expand Down Expand Up @@ -2220,6 +2238,12 @@ export class DefaultClient implements Client {
() => this.languageClient.sendRequest(CppContextRequest, null, token), token);
}

public async getProjectContext(token: vscode.CancellationToken): Promise<ProjectContextResult> {
await withCancellation(this.ready, token);
return DefaultClient.withLspCancellationHandling(
() => this.languageClient.sendRequest(ProjectContextRequest, null, token), token);
}

/**
* a Promise that can be awaited to know when it's ok to proceed.
*
Expand Down Expand Up @@ -4123,4 +4147,5 @@ class NullClient implements Client {
addTrustedCompiler(path: string): Promise<void> { return Promise.resolve(); }
getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult> { return Promise.resolve({} as GetIncludesResult); }
getChatContext(token: vscode.CancellationToken): Promise<ChatContextResult> { return Promise.resolve({} as ChatContextResult); }
getProjectContext(token: vscode.CancellationToken): Promise<ProjectContextResult> { return Promise.resolve({} as ProjectContextResult); }
}
23 changes: 15 additions & 8 deletions Extension/src/LanguageServer/copilotProviders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import * as vscode from 'vscode';
import * as util from '../common';
import { ChatContextResult, GetIncludesResult } from './client';
import { GetIncludesResult } from './client';
import { getActiveClient } from './extension';
import { getProjectContext } from './lmTool';

export interface CopilotTrait {
name: string;
Expand Down Expand Up @@ -38,19 +39,25 @@ export async function registerRelatedFilesProvider(): Promise<void> {

const getIncludesHandler = async () => (await getIncludesWithCancellation(1, token))?.includedFiles.map(file => vscode.Uri.file(file)) ?? [];
const getTraitsHandler = async () => {
const chatContext: ChatContextResult | undefined = await (getActiveClient().getChatContext(token) ?? undefined);
const cppContext = await getProjectContext(context, token);

if (!chatContext) {
if (!cppContext) {
return undefined;
}

let traits: CopilotTrait[] = [
{ name: "language", value: chatContext.language, includeInPrompt: true, promptTextOverride: `The language is ${chatContext.language}.` },
{ name: "compiler", value: chatContext.compiler, includeInPrompt: true, promptTextOverride: `This project compiles using ${chatContext.compiler}.` },
{ name: "standardVersion", value: chatContext.standardVersion, includeInPrompt: true, promptTextOverride: `This project uses the ${chatContext.standardVersion} language standard.` },
{ name: "targetPlatform", value: chatContext.targetPlatform, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetPlatform}.` },
{ name: "targetArchitecture", value: chatContext.targetArchitecture, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetArchitecture}.` }
{ name: "language", value: cppContext.language, includeInPrompt: true, promptTextOverride: `The language is ${cppContext.language}.` },
{ name: "compiler", value: cppContext.compiler, includeInPrompt: true, promptTextOverride: `This project compiles using ${cppContext.compiler}.` },
{ name: "standardVersion", value: cppContext.standardVersion, includeInPrompt: true, promptTextOverride: `This project uses the ${cppContext.standardVersion} language standard.` },
{ name: "targetPlatform", value: cppContext.targetPlatform, includeInPrompt: true, promptTextOverride: `This build targets ${cppContext.targetPlatform}.` },
{ name: "targetArchitecture", value: cppContext.targetArchitecture, includeInPrompt: true, promptTextOverride: `This build targets ${cppContext.targetArchitecture}.` }
];
if (cppContext.compilerArguments.length > 0) {
traits.push({ name: "compilerArguments", value: cppContext.compilerArguments, includeInPrompt: true, promptTextOverride: `The compiler command line arguments may contain: ${cppContext.compilerArguments}.` });
}
if (cppContext.compilerUserDefinesRelavant.length > 0) {
traits.push({ name: "compilerUserDefines", value: cppContext.compilerUserDefinesRelavant, includeInPrompt: true, promptTextOverride: `These compiler command line defines may be relavant: ${cppContext.compilerUserDefinesRelavant}.` });
}

const excludeTraits = context.flags.copilotcppExcludeTraits as string[] ?? [];
traits = traits.filter(trait => !excludeTraits.includes(trait.name));
Expand Down
118 changes: 106 additions & 12 deletions Extension/src/LanguageServer/lmTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,23 @@ import { localize } from 'vscode-nls';
import * as util from '../common';
import * as logger from '../logger';
import * as telemetry from '../telemetry';
import { ChatContextResult } from './client';
import { ChatContextResult, ProjectContextResult } from './client';
import { getClients } from './extension';
import { checkTime } from './utils';

const MSVC: string = 'MSVC';
const Clang: string = 'Clang';
const GCC: string = 'GCC';
const knownValues: { [Property in keyof ChatContextResult]?: { [id: string]: string } } = {
language: {
'c': 'C',
'cpp': 'C++',
'cuda-cpp': 'CUDA C++'
},
compiler: {
'msvc': 'MSVC',
'clang': 'Clang',
'gcc': 'GCC'
'msvc': MSVC,
'clang': Clang,
'gcc': GCC
},
standardVersion: {
'c++98': 'C++98',
Expand All @@ -44,6 +48,102 @@ const knownValues: { [Property in keyof ChatContextResult]?: { [id: string]: str
}
};

function formatChatContext(context: ChatContextResult | ProjectContextResult): void {
type KnownKeys = 'language' | 'standardVersion' | 'compiler' | 'targetPlatform' | 'targetArchitecture';
for (const key in knownValues) {
const knownKey = key as KnownKeys

Check failure on line 54 in Extension/src/LanguageServer/lmTool.ts

View workflow job for this annotation

GitHub Actions / job / build

Missing semicolon

Check failure on line 54 in Extension/src/LanguageServer/lmTool.ts

View workflow job for this annotation

GitHub Actions / job / build

Missing semicolon

Check failure on line 54 in Extension/src/LanguageServer/lmTool.ts

View workflow job for this annotation

GitHub Actions / job / build

Missing semicolon
if (knownValues[knownKey] && context[knownKey]) {
// Clear the value if it's not in the known values.
context[knownKey] = knownValues[knownKey][context[knownKey]] || "";
}
}
}

export interface ProjectContext {
language: string;
standardVersion: string;
compiler: string;
targetPlatform: string;
targetArchitecture: string;
compilerArguments: string;
compilerUserDefinesRelavant: string;
}

const matchNothingRegex = /(?!.*)/;

// To be updated after A/B experiments, match nothing for now.
const defaultCompilerArgumentFilters: { [id: string]: RegExp } = {
MSVC: matchNothingRegex,
Clang: matchNothingRegex,
GCC: matchNothingRegex
};

function filterComplierArguments(compiler: string, compilerArguments: string[], context: { flags: Record<string, unknown> }): string[] {
const defaultFilter = defaultCompilerArgumentFilters[compiler] ?? matchNothingRegex;
let additionalFilter: RegExp | undefined;
switch (compiler) {
case MSVC:
additionalFilter = context.flags.copilotcppMsvcCompilerArgumentFilter ? new RegExp(context.flags.copilotcppMsvcCompilerArgumentFilter as string) : undefined;
break;
case Clang:
additionalFilter = context.flags.copilotcppClangCompilerArgumentFilter ? new RegExp(context.flags.copilotcppClangCompilerArgumentFilter as string) : undefined;
break;
case GCC:
additionalFilter = context.flags.copilotcppGccCompilerArgumentFilter ? new RegExp(context.flags.copilotcppGccCompilerArgumentFilter as string) : undefined;
break;
}

return compilerArguments.filter(arg => defaultFilter.test(arg) || additionalFilter?.test(arg));
}

export async function getProjectContext(context: { flags: Record<string, unknown> }, token: vscode.CancellationToken): Promise<ProjectContext | undefined> {
const telemetryProperties: Record<string, string> = {};
try {
const projectContext = await checkTime<ProjectContextResult | undefined>(async () => await getClients()?.ActiveClient?.getProjectContext(token) ?? undefined);
telemetryProperties["time"] = projectContext.time.toString();
if (!projectContext.result) {
return undefined;
}

formatChatContext(projectContext.result);

const filteredcompilerArguments = filterComplierArguments(projectContext.result.compiler, projectContext.result.fileContext.compilerArgs, context);

telemetryProperties["language"] = projectContext.result.language;
telemetryProperties["compiler"] = projectContext.result.compiler;
telemetryProperties["standardVersion"] = projectContext.result.standardVersion;
telemetryProperties["targetPlatform"] = projectContext.result.targetPlatform;
telemetryProperties["targetArchitecture"] = projectContext.result.targetArchitecture;
telemetryProperties["compilerArgumentCount"] = projectContext.result.fileContext.compilerArgs.length.toString();
telemetryProperties["filteredCompilerArgumentCount"] = filteredcompilerArguments.length.toString();
telemetryProperties["compilerUserDefinesRelavantCount"] = projectContext.result.fileContext.compilerUserDefinesRelavant.length.toString();
telemetryProperties["targetArcompilerOriginalUserDefineCounthitecture"] = projectContext.result.fileContext.compilerOriginalUserDefineCount.toString();
telemetryProperties["targetArchmacroReferenceCountitecture"] = projectContext.result.fileContext.macroReferenceCount.toString();

return {
language: projectContext.result.language,
standardVersion: projectContext.result.standardVersion,
compiler: projectContext.result.compiler,
targetPlatform: projectContext.result.targetPlatform,
targetArchitecture: projectContext.result.targetArchitecture,
compilerArguments: (filteredcompilerArguments.length > 0) ? filteredcompilerArguments.join(' ') : '',
compilerUserDefinesRelavant: (projectContext.result.fileContext.compilerUserDefinesRelavant.length > 0) ? projectContext.result.fileContext.compilerUserDefinesRelavant.join(', ') : ''
};
}
catch {
try {
logger.getOutputChannelLogger().appendLine(localize("copilot.cppcontext.error", "Error while retrieving the project context."));
}
catch {
// Intentionally swallow any exception.
}
telemetryProperties["error"] = "true";
return undefined;
} finally {
telemetry.logLanguageModelToolEvent('Completions/tool', telemetryProperties);
}
}

export class CppConfigurationLanguageModelTool implements vscode.LanguageModelTool<void> {
public async invoke(options: vscode.LanguageModelToolInvocationOptions<void>, token: vscode.CancellationToken): Promise<vscode.LanguageModelToolResult> {
return new vscode.LanguageModelToolResult([
Expand All @@ -63,13 +163,7 @@ export class CppConfigurationLanguageModelTool implements vscode.LanguageModelTo
return 'No configuration information is available for the active document.';
}

for (const key in knownValues) {
const knownKey = key as keyof ChatContextResult;
if (knownValues[knownKey] && chatContext[knownKey]) {
// Clear the value if it's not in the known values.
chatContext[knownKey] = knownValues[knownKey][chatContext[knownKey]] || "";
}
}
formatChatContext(chatContext);

let contextString = "";
if (chatContext.language) {
Expand Down Expand Up @@ -100,7 +194,7 @@ export class CppConfigurationLanguageModelTool implements vscode.LanguageModelTo
telemetryProperties["error"] = "true";
return "";
} finally {
telemetry.logLanguageModelToolEvent('cpp', telemetryProperties);
telemetry.logLanguageModelToolEvent('Chat/Tool/cpp', telemetryProperties);
}
}

Expand Down
6 changes: 6 additions & 0 deletions Extension/src/LanguageServer/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,9 @@ export async function withCancellation<T>(promise: Promise<T>, token: vscode.Can
});
});
}

export async function checkTime<T>(fn: () => Promise<T>): Promise<{ result: T, time: number }> {

Check failure on line 116 in Extension/src/LanguageServer/utils.ts

View workflow job for this annotation

GitHub Actions / job / build

Expected a semicolon

Check failure on line 116 in Extension/src/LanguageServer/utils.ts

View workflow job for this annotation

GitHub Actions / job / build

Expected a semicolon

Check failure on line 116 in Extension/src/LanguageServer/utils.ts

View workflow job for this annotation

GitHub Actions / job / build

Expected a semicolon
const start = Date.now();
const result = await fn();
return { result, time: Date.now() - start };
}
2 changes: 1 addition & 1 deletion Extension/src/telemetry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ export function logLanguageServerEvent(eventName: string, properties?: Record<st
export function logLanguageModelToolEvent(eventName: string, properties?: Record<string, string>, metrics?: Record<string, number>): void {
const sendTelemetry = () => {
if (experimentationTelemetry) {
const eventNamePrefix: string = "C_Cpp/Copilot/Chat/Tool/";
const eventNamePrefix: string = "C_Cpp/Copilot/";
experimentationTelemetry.sendTelemetryEvent(eventNamePrefix + eventName, properties, metrics);
}
};
Expand Down
Loading

0 comments on commit cf72c95

Please sign in to comment.