Skip to content

Commit

Permalink
Support A/B Compiler Arguments Traits to Completions Prompt
Browse files Browse the repository at this point in the history
- Depends on cpptools' update to provide ProjectContextResult.
- Added two new traits
  - compilerArguments: a list of compiler command arguments that could affect Copilot generating completions.
  - compilerUserDefines: a list of compiler command defines that could affect Copilot generating completions. Macro references are used to exclude the ones that are not relavent.
- A/B Experimental flags
  - copilotcppTraits: boolean flag to enable cpp traits
  - copilotcppExcludeTraits: string array to exclude individual trait, i.e., compilerArguments.
  - copilotcppMsvcCompilerArgumentFilter: regex string to match compiler arguments for GCC.
  - copilotcppClangCompilerArgumentFilter: regex string to match compiler arguments for Clang.
  - copilotcppGccCompilerArgumentFilter: regex string to match compiler arguments for MSVC.
  - copilotcppCompilerArgumentDirectAskMap: a stringify map string to map arguments to direct ask statements.
  - copilotcppMacroReferenceFilter: regex string to filter macro references for telemetry.
  • Loading branch information
kuchungmsft committed Nov 19, 2024
1 parent c9cae0b commit caa3897
Show file tree
Hide file tree
Showing 7 changed files with 725 additions and 90 deletions.
24 changes: 24 additions & 0 deletions Extension/src/LanguageServer/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,21 @@ export interface ChatContextResult {
targetArchitecture: string;
}

export interface FileContextResult {
compilerArguments: string[];
compilerUserDefines: string[];
macroReferences: string[];
}

export interface ProjectContextResult {
language: string;
standardVersion: string;
compiler: string;
targetPlatform: string;
targetArchitecture: string;
fileContext: FileContextResult;
}

// Requests
const PreInitializationRequest: RequestType<void, string, void> = new RequestType<void, string, void>('cpptools/preinitialize');
const InitializationRequest: RequestType<CppInitializationParams, void, void> = new RequestType<CppInitializationParams, void, void>('cpptools/initialize');
Expand All @@ -561,6 +576,7 @@ const GenerateDoxygenCommentRequest: RequestType<GenerateDoxygenCommentParams, G
const ChangeCppPropertiesRequest: RequestType<CppPropertiesParams, void, void> = new RequestType<CppPropertiesParams, void, void>('cpptools/didChangeCppProperties');
const IncludesRequest: RequestType<GetIncludesParams, GetIncludesResult, void> = new RequestType<GetIncludesParams, GetIncludesResult, void>('cpptools/getIncludes');
const CppContextRequest: RequestType<void, ChatContextResult, void> = new RequestType<void, ChatContextResult, void>('cpptools/getChatContext');
const ProjectContextRequest: RequestType<void, ProjectContextResult, void> = new RequestType<void, ProjectContextResult, void>('cpptools/getProjectContext');

// Notifications to the server
const DidOpenNotification: NotificationType<DidOpenTextDocumentParams> = new NotificationType<DidOpenTextDocumentParams>('textDocument/didOpen');
Expand Down Expand Up @@ -792,6 +808,7 @@ export interface Client {
addTrustedCompiler(path: string): Promise<void>;
getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult>;
getChatContext(token: vscode.CancellationToken): Promise<ChatContextResult>;
getProjectContext(token: vscode.CancellationToken): Promise<ProjectContextResult>;
}

export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client {
Expand Down Expand Up @@ -2220,6 +2237,12 @@ export class DefaultClient implements Client {
() => this.languageClient.sendRequest(CppContextRequest, null, token), token);
}

public async getProjectContext(token: vscode.CancellationToken): Promise<ProjectContextResult> {
await withCancellation(this.ready, token);
return DefaultClient.withLspCancellationHandling(
() => this.languageClient.sendRequest(ProjectContextRequest, null, token), token);
}

/**
* a Promise that can be awaited to know when it's ok to proceed.
*
Expand Down Expand Up @@ -4123,4 +4146,5 @@ class NullClient implements Client {
addTrustedCompiler(path: string): Promise<void> { return Promise.resolve(); }
getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult> { return Promise.resolve({} as GetIncludesResult); }
getChatContext(token: vscode.CancellationToken): Promise<ChatContextResult> { return Promise.resolve({} as ChatContextResult); }
getProjectContext(token: vscode.CancellationToken): Promise<ProjectContextResult> { return Promise.resolve({} as ProjectContextResult); }
}
52 changes: 44 additions & 8 deletions Extension/src/LanguageServer/copilotProviders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import * as vscode from 'vscode';
import * as util from '../common';
import { ChatContextResult, GetIncludesResult } from './client';
import { GetIncludesResult } from './client';
import { getActiveClient } from './extension';
import { getProjectContext } from './lmTool';

export interface CopilotTrait {
name: string;
Expand Down Expand Up @@ -38,19 +39,54 @@ export async function registerRelatedFilesProvider(): Promise<void> {

const getIncludesHandler = async () => (await getIncludesWithCancellation(1, token))?.includedFiles.map(file => vscode.Uri.file(file)) ?? [];
const getTraitsHandler = async () => {
const chatContext: ChatContextResult | undefined = await (getActiveClient().getChatContext(token) ?? undefined);
const cppContext = await getProjectContext(context, token);

if (!chatContext) {
if (!cppContext) {
return undefined;
}

let traits: CopilotTrait[] = [
{ name: "language", value: chatContext.language, includeInPrompt: true, promptTextOverride: `The language is ${chatContext.language}.` },
{ name: "compiler", value: chatContext.compiler, includeInPrompt: true, promptTextOverride: `This project compiles using ${chatContext.compiler}.` },
{ name: "standardVersion", value: chatContext.standardVersion, includeInPrompt: true, promptTextOverride: `This project uses the ${chatContext.standardVersion} language standard.` },
{ name: "targetPlatform", value: chatContext.targetPlatform, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetPlatform}.` },
{ name: "targetArchitecture", value: chatContext.targetArchitecture, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetArchitecture}.` }
{ name: "intellisense", value: 'intellisense', includeInPrompt: true, promptTextOverride: `IntelliSense is currently configured with the following compiler information. It's best effort to reflect the active configuration, and the project may have more configurations targeting different platforms.` },
{ name: "intellisenseBegin", value: 'Begin', includeInPrompt: true, promptTextOverride: `Begin of IntelliSense information.` }
];
if (cppContext.language) {
traits.push({ name: "language", value: cppContext.language, includeInPrompt: true, promptTextOverride: `The language is ${cppContext.language}.` });
}
if (cppContext.compiler) {
traits.push({ name: "compiler", value: cppContext.compiler, includeInPrompt: true, promptTextOverride: `This project compiles using ${cppContext.compiler}.` });
}
if (cppContext.standardVersion) {
traits.push({ name: "standardVersion", value: cppContext.standardVersion, includeInPrompt: true, promptTextOverride: `This project uses the ${cppContext.standardVersion} language standard.` });
}
if (cppContext.targetPlatform) {
traits.push({ name: "targetPlatform", value: cppContext.targetPlatform, includeInPrompt: true, promptTextOverride: `This build targets ${cppContext.targetPlatform}.` });
}
if (cppContext.targetArchitecture) {
traits.push({ name: "targetArchitecture", value: cppContext.targetArchitecture, includeInPrompt: true, promptTextOverride: `This build targets ${cppContext.targetArchitecture}.` });
}
let directAsks: string = '';
if (cppContext.compilerArguments.length > 0) {
const directAskMap: { [key: string]: string } = JSON.parse(context.flags.copilotcppCompilerArgumentDirectAskMap as string ?? '{}');
const updatedArguments = cppContext.compilerArguments.filter(arg => {
if (directAskMap[arg]) {
directAsks += `${directAskMap[arg]} `;
return false;
}
return true;
});

const compilerArgumentsValue = updatedArguments.join(", ");
traits.push({ name: "compilerArguments", value: compilerArgumentsValue, includeInPrompt: true, promptTextOverride: `The compiler arguments include: ${compilerArgumentsValue}.` });
}
if (cppContext.compilerUserDefinesRelevant.length > 0) {
const compilerUserDefinesValue = cppContext.compilerUserDefinesRelevant.join(", ");
traits.push({ name: "compilerUserDefines", value: compilerUserDefinesValue, includeInPrompt: true, promptTextOverride: `These compiler command line user defines may be relevent: ${compilerUserDefinesValue}.` });
}
if (directAsks) {
traits.push({ name: "directAsks", value: directAsks, includeInPrompt: true, promptTextOverride: directAsks });
}

traits.push({ name: "intellisenseEnd", value: 'End', includeInPrompt: true, promptTextOverride: `End of IntelliSense information.` });

const excludeTraits = context.flags.copilotcppExcludeTraits as string[] ?? [];
traits = traits.filter(trait => !excludeTraits.includes(trait.name));
Expand Down
157 changes: 145 additions & 12 deletions Extension/src/LanguageServer/lmTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,23 @@ import { localize } from 'vscode-nls';
import * as util from '../common';
import * as logger from '../logger';
import * as telemetry from '../telemetry';
import { ChatContextResult } from './client';
import { ChatContextResult, ProjectContextResult } from './client';
import { getClients } from './extension';
import { checkTime } from './utils';

const MSVC: string = 'MSVC';
const Clang: string = 'Clang';
const GCC: string = 'GCC';
const knownValues: { [Property in keyof ChatContextResult]?: { [id: string]: string } } = {
language: {
'c': 'C',
'cpp': 'C++',
'cuda-cpp': 'CUDA C++'
},
compiler: {
'msvc': 'MSVC',
'clang': 'Clang',
'gcc': 'GCC'
'msvc': MSVC,
'clang': Clang,
'gcc': GCC
},
standardVersion: {
'c++98': 'C++98',
Expand All @@ -44,6 +48,141 @@ const knownValues: { [Property in keyof ChatContextResult]?: { [id: string]: str
}
};

function formatChatContext(context: ChatContextResult | ProjectContextResult): void {
type KnownKeys = 'language' | 'standardVersion' | 'compiler' | 'targetPlatform';
for (const key in knownValues) {
const knownKey = key as KnownKeys;
if (knownValues[knownKey] && context[knownKey]) {
// Clear the value if it's not in the known values.
context[knownKey] = knownValues[knownKey][context[knownKey]] || "";
}
}
}

export interface ProjectContext {
language: string;
standardVersion: string;
compiler: string;
targetPlatform: string;
targetArchitecture: string;
compilerArguments: string[];
compilerUserDefinesRelevant: string[];
}

// Set these values for local testing purpose without involving control tower.
const defaultCompilerArgumentFilters: { [id: string]: RegExp | undefined } = {
MSVC: undefined, // Example: /^(\/std:.*|\/EHs-c-|\/GR-|\/await.*)$/,
Clang: undefined,
GCC: undefined // Example: /^(-std=.*|-fno-rtti|-fno-exceptions)$/
};

function filterComplierArguments(compiler: string, compilerArguments: string[], context: { flags: Record<string, unknown> }): string[] {
const defaultFilter: RegExp | undefined = defaultCompilerArgumentFilters[compiler] ?? undefined;
let additionalFilter: RegExp | undefined;
switch (compiler) {
case MSVC:
additionalFilter = context.flags.copilotcppMsvcCompilerArgumentFilter ? new RegExp(context.flags.copilotcppMsvcCompilerArgumentFilter as string) : undefined;
break;
case Clang:
additionalFilter = context.flags.copilotcppClangCompilerArgumentFilter ? new RegExp(context.flags.copilotcppClangCompilerArgumentFilter as string) : undefined;
break;
case GCC:
additionalFilter = context.flags.copilotcppGccCompilerArgumentFilter ? new RegExp(context.flags.copilotcppGccCompilerArgumentFilter as string) : undefined;
break;
}

return compilerArguments.filter(arg => defaultFilter?.test(arg) || additionalFilter?.test(arg));
}

// We can set up copilotcppMacroReferenceFilter feature flag to filter macro references to learn about
// macro usage distribution, i.e., compiler or platform specific macros, or even the presence of certain macros.
const defaultMacroReferenceFilter: RegExp | undefined = undefined;
function filterMacroReferences(macroReferences: string[], context: { flags: Record<string, unknown> }): string[] {
const filter: RegExp | undefined = context.flags.copilotcppMacroReferenceFilter ? new RegExp(context.flags.copilotcppMacroReferenceFilter as string) : undefined;

return macroReferences.filter(macro => defaultMacroReferenceFilter?.test(macro) || filter?.test(macro));
}

export async function getProjectContext(context: { flags: Record<string, unknown> }, token: vscode.CancellationToken): Promise<ProjectContext | undefined> {
const telemetryProperties: Record<string, string> = {};
try {
const projectContext = await checkTime<ProjectContextResult | undefined>(async () => await getClients()?.ActiveClient?.getProjectContext(token) ?? undefined);
telemetryProperties["time"] = projectContext.time.toString();
if (!projectContext.result) {
return undefined;
}

formatChatContext(projectContext.result);

const result: ProjectContext = {
language: projectContext.result.language,
standardVersion: projectContext.result.standardVersion,
compiler: projectContext.result.compiler,
targetPlatform: projectContext.result.targetPlatform,
targetArchitecture: projectContext.result.targetArchitecture,
compilerArguments: [],
compilerUserDefinesRelevant: []
};

if (projectContext.result.language) {
telemetryProperties["language"] = projectContext.result.language;
}
if (projectContext.result.compiler) {
telemetryProperties["compiler"] = projectContext.result.compiler;
}
if (projectContext.result.standardVersion) {
telemetryProperties["standardVersion"] = projectContext.result.standardVersion;
}
if (projectContext.result.targetPlatform) {
telemetryProperties["targetPlatform"] = projectContext.result.targetPlatform;
}
if (projectContext.result.targetArchitecture) {
telemetryProperties["targetArchitecture"] = projectContext.result.targetArchitecture;
}
telemetryProperties["compilerArgumentCount"] = projectContext.result.fileContext.compilerArguments.length.toString();
// Telemtry to learn about the argument and macro distribution. The filtered arguments and macro references
// are expected to be non-PII.
if (projectContext.result.fileContext.compilerArguments.length) {
const filteredCompilerArguments = filterComplierArguments(projectContext.result.compiler, projectContext.result.fileContext.compilerArguments, context);
if (filteredCompilerArguments.length > 0) {
telemetryProperties["filteredCompilerArguments"] = filteredCompilerArguments.join(', ');
result.compilerArguments = filteredCompilerArguments;
}
}
telemetryProperties["compilerUserDefinesCount"] = projectContext.result.fileContext.compilerUserDefines.length.toString();
if (projectContext.result.fileContext.compilerUserDefines.length > 0) {
const userDefinesWithoutValue = projectContext.result.fileContext.compilerUserDefines.map(value => value.split('=')[0]);
const userDefinesRelatedToThisFile = userDefinesWithoutValue.filter(value => projectContext.result?.fileContext.macroReferences.includes(value));
if (userDefinesRelatedToThisFile.length > 0) {
// Don't care the actual name of the user define, just the count that's relevant.
telemetryProperties["compilerUserDefinesRelevantCount"] = userDefinesRelatedToThisFile.length.toString();
result.compilerUserDefinesRelevant = userDefinesRelatedToThisFile;
}
}
telemetryProperties["macroReferenceCount"] = projectContext.result.fileContext.macroReferences.length.toString();
if (projectContext.result.fileContext.macroReferences.length > 0) {
const filteredMacroReferences = filterMacroReferences(projectContext.result.fileContext.macroReferences, context);
if (filteredMacroReferences.length > 0) {
telemetryProperties["filteredMacroReferences"] = filteredMacroReferences.join(', ');
}
}

return result;
}
catch {
try {
logger.getOutputChannelLogger().appendLine(localize("copilot.cppcontext.error", "Error while retrieving the project context."));
}
catch {
// Intentionally swallow any exception.
}
telemetryProperties["error"] = "true";
return undefined;
} finally {
telemetry.logLanguageModelToolEvent('Completions/tool', telemetryProperties);
}
}

export class CppConfigurationLanguageModelTool implements vscode.LanguageModelTool<void> {
public async invoke(options: vscode.LanguageModelToolInvocationOptions<void>, token: vscode.CancellationToken): Promise<vscode.LanguageModelToolResult> {
return new vscode.LanguageModelToolResult([
Expand All @@ -63,13 +202,7 @@ export class CppConfigurationLanguageModelTool implements vscode.LanguageModelTo
return 'No configuration information is available for the active document.';
}

for (const key in knownValues) {
const knownKey = key as keyof ChatContextResult;
if (knownValues[knownKey] && chatContext[knownKey]) {
// Clear the value if it's not in the known values.
chatContext[knownKey] = knownValues[knownKey][chatContext[knownKey]] || "";
}
}
formatChatContext(chatContext);

let contextString = "";
if (chatContext.language) {
Expand Down Expand Up @@ -100,7 +233,7 @@ export class CppConfigurationLanguageModelTool implements vscode.LanguageModelTo
telemetryProperties["error"] = "true";
return "";
} finally {
telemetry.logLanguageModelToolEvent('cpp', telemetryProperties);
telemetry.logLanguageModelToolEvent('Chat/Tool/cpp', telemetryProperties);
}
}

Expand Down
6 changes: 6 additions & 0 deletions Extension/src/LanguageServer/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,9 @@ export async function withCancellation<T>(promise: Promise<T>, token: vscode.Can
});
});
}

export async function checkTime<T>(fn: () => Promise<T>): Promise<{ result: T; time: number }> {
const start = Date.now();
const result = await fn();
return { result, time: Date.now() - start };
}
2 changes: 1 addition & 1 deletion Extension/src/telemetry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ export function logLanguageServerEvent(eventName: string, properties?: Record<st
export function logLanguageModelToolEvent(eventName: string, properties?: Record<string, string>, metrics?: Record<string, number>): void {
const sendTelemetry = () => {
if (experimentationTelemetry) {
const eventNamePrefix: string = "C_Cpp/Copilot/Chat/Tool/";
const eventNamePrefix: string = "C_Cpp/Copilot/";
experimentationTelemetry.sendTelemetryEvent(eventNamePrefix + eventName, properties, metrics);
}
};
Expand Down
Loading

0 comments on commit caa3897

Please sign in to comment.