diff --git a/src/plugins/ForwardMerger/forward_merger.ts b/src/plugins/ForwardMerger/forward_merger.ts index 5165133..5fff956 100644 --- a/src/plugins/ForwardMerger/forward_merger.ts +++ b/src/plugins/ForwardMerger/forward_merger.ts @@ -16,7 +16,7 @@ import { OpsBotPlugin } from "../../plugin"; import { PayloadRepository } from "../../types"; -import { isVersionedBranch, getVersionFromBranch } from "../../shared"; +import { isVersionedBranch, getVersionFromBranch, isVersionedUCXBranch } from "../../shared"; import { basename } from "path"; import { Context } from "probot"; import { Octokit } from "@octokit/rest" @@ -40,7 +40,7 @@ export class ForwardMerger extends OpsBotPlugin { async mergeForward(): Promise { if (await this.pluginIsDisabled()) return; - if (!isVersionedBranch(this.currentBranch)) { + if (!isVersionedBranch(this.currentBranch) && !isVersionedUCXBranch(this.currentBranch)) { this.logger.info("Will not merge forward on non-versioned branch"); return; } diff --git a/src/shared.ts b/src/shared.ts index 4ab0c04..8bd533a 100644 --- a/src/shared.ts +++ b/src/shared.ts @@ -51,6 +51,16 @@ export const isVersionedBranch = (branchName: string): boolean => { return Boolean(branchName.match(versionedBranchExp)); }; +/** + * Returns true if the provided string is a versioned branch that follows the ucxx/py versioning scheme + * (i.e. "branch-0.36", "branch-0.40", etc.) + * @param branchName + */ +export const isVersionedUCXBranch = (branchName: string): boolean => { + const regex = "/^branch-\d{1,2}\.\d\d$/"; + return Boolean(branchName.match(regex)); +}; + /** * Returns the RAPIDS version from a versioned branch name */