diff --git a/.github/workflows/build_portable_linux_pytorch_wheels.yml b/.github/workflows/build_portable_linux_pytorch_wheels.yml index 8f7906a99..6d5187987 100644 --- a/.github/workflows/build_portable_linux_pytorch_wheels.yml +++ b/.github/workflows/build_portable_linux_pytorch_wheels.yml @@ -116,7 +116,7 @@ jobs: --index-url "${{ inputs.cloudfront_url }}/${{ inputs.amdgpu_family }}/" \ --clean \ --output-dir ${{ env.PACKAGE_DIST_DIR }} ${{ env.optional_build_prod_arguments }} - echo "torch_version=`cd ${{ env.PACKAGE_DIST_DIR }}; python -c 'import glob; print(glob.glob("torch-*.whl")[0].split("-")[1])'`" >> $GITHUB_OUTPUT + python ./build_tools/github_actions/write_torch_version.py - name: Configure AWS Credentials if: always() diff --git a/.github/workflows/build_windows_pytorch_wheels.yml b/.github/workflows/build_windows_pytorch_wheels.yml index 3616e3de7..110a89e28 100644 --- a/.github/workflows/build_windows_pytorch_wheels.yml +++ b/.github/workflows/build_windows_pytorch_wheels.yml @@ -64,6 +64,8 @@ jobs: PACKAGE_DIST_DIR: ${{ github.workspace }}\output\packages\dist S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python" optional_build_prod_arguments: "" + outputs: + torch_version: ${{ steps.build-pytorch-wheels.outputs.torch_version }} defaults: run: # Note: there are mixed uses of 'bash' (this default) and 'cmd' below @@ -112,6 +114,7 @@ jobs: --rocm-version ${{ inputs.rocm_version }} - name: Build PyTorch Wheels + id: build-pytorch-wheels # Using 'cmd' here is load bearing! There are configuration issues when # run under 'bash': https://github.com/ROCm/TheRock/issues/827#issuecomment-3025858800 shell: cmd @@ -126,6 +129,7 @@ jobs: --clean ^ --output-dir ${{ env.PACKAGE_DIST_DIR }} ^ ${{ env.optional_build_prod_arguments }} + python ./build_tools/github_actions/write_torch_version.py - name: Configure AWS Credentials if: always() diff --git a/build_tools/github_actions/write_torch_version.py b/build_tools/github_actions/write_torch_version.py new file mode 100755 index 000000000..111e58438 --- /dev/null +++ b/build_tools/github_actions/write_torch_version.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +"""Writes torch_version to GITHUB_OUTPUT.""" + +import os +import glob + +from github_actions_utils import * + + +def main(argv: list[str]): + # Get the torch version from the first torch wheel in PACKAGE_DIST_DIR. + package_dist_dir = os.getenv("PACKAGE_DIST_DIR") + version = glob.glob("torch-*.whl", root_dir=package_dist_dir)[0].split("-")[1] + gha_set_output({"torch_version": version}) + + +if __name__ == "__main__": + main(sys.argv[1:])