diff --git a/torchci/lib/bot/verifyDisableTestIssueBot.ts b/torchci/lib/bot/verifyDisableTestIssueBot.ts index 75d8e70b78..7e873c04d9 100644 --- a/torchci/lib/bot/verifyDisableTestIssueBot.ts +++ b/torchci/lib/bot/verifyDisableTestIssueBot.ts @@ -45,12 +45,17 @@ async function getValidationComment( return [0, ""]; } -// Returns the platform labels that are expected, and invalid labels that we do not expect to be there -export function getExpectedPlatformLabels( +// Returns the platform module labels that are expected, and invalid labels that we do not expect to be there +export function getExpectedPlatformModuleLabels( platforms: string[], labels: string[] ): [string[], string[]] { - let supportedPlatformLabels = Array.from(supportedPlatforms.values()).flat(); + let supportedPlatformLabels = Array.from(supportedPlatforms.values()) + .flat() + // Quick hack to make sure oncall: pt2 doesn't get deleted. + // TODO: figure out a better way to differentiate between labels that should + // stay and labels that shouldn't + .filter((label) => label.startsWith("module: ")); let existingPlatformLabels = labels.filter((label) => supportedPlatformLabels.includes(label) ); @@ -302,7 +307,7 @@ export default function verifyDisableTestIssueBot(app: Probot): void { } else { // check labels, add labels as needed let [expectedPlatformLabels, invalidPlatformLabels] = - getExpectedPlatformLabels(platformsToSkip, labels); + getExpectedPlatformModuleLabels(platformsToSkip, labels); let labelsSet = new Set(labels); if (!expectedPlatformLabels.every((label) => labelsSet.has(label))) { await context.octokit.issues.addLabels({ diff --git a/torchci/test/verifyDisableTestIssue.test.ts b/torchci/test/verifyDisableTestIssue.test.ts index d964959b52..1fef7cbe25 100644 --- a/torchci/test/verifyDisableTestIssue.test.ts +++ b/torchci/test/verifyDisableTestIssue.test.ts @@ -325,32 +325,43 @@ describe("verify-disable-test-issue", () => { expect(comment.includes("ERROR")).toBeFalsy(); }); - test("various getExpectedPlatformLabels tests", async () => { - expect(await bot.getExpectedPlatformLabels(["linux"], ["random"])).toEqual([ - [], - [], - ]); + test("various getExpectedPlatformModuleLabels tests", async () => { expect( - await bot.getExpectedPlatformLabels(["inductor"], ["random"]) + await bot.getExpectedPlatformModuleLabels(["linux"], ["random"]) + ).toEqual([[], []]); + expect( + await bot.getExpectedPlatformModuleLabels(["inductor"], ["random"]) ).toEqual([["oncall: pt2"], []]); expect( - await bot.getExpectedPlatformLabels(["linux"], ["random", "module: rocm"]) + await bot.getExpectedPlatformModuleLabels( + ["linux"], + ["random", "module: rocm"] + ) ).toEqual([[], ["module: rocm"]]); expect( - await bot.getExpectedPlatformLabels(["rocm"], ["random", "module: rocm"]) + await bot.getExpectedPlatformModuleLabels( + ["rocm"], + ["random", "module: rocm"] + ) ).toEqual([["module: rocm"], []]); expect( - await bot.getExpectedPlatformLabels( + await bot.getExpectedPlatformModuleLabels( ["dynamo", "inductor"], ["random", "module: rocm"] ) ).toEqual([["oncall: pt2"], ["module: rocm"]]); expect( - await bot.getExpectedPlatformLabels( + await bot.getExpectedPlatformModuleLabels( ["linux", "rocm"], ["random", "module: rocm"] ) ).toEqual([[], ["module: rocm"]]); + expect( + await bot.getExpectedPlatformModuleLabels( + ["linux", "rocm"], + ["random", "module: rocm", "oncall: pt2"] + ) + ).toEqual([[], ["module: rocm"]]); }); });