Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_portable_linux_pytorch_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ jobs:
--index-url "${{ inputs.cloudfront_url }}/${{ inputs.amdgpu_family }}/" \
--clean \
--output-dir ${{ env.PACKAGE_DIST_DIR }} ${{ env.optional_build_prod_arguments }}
echo "torch_version=`cd ${{ env.PACKAGE_DIST_DIR }}; python -c 'import glob; print(glob.glob("torch-*.whl")[0].split("-")[1])'`" >> $GITHUB_OUTPUT
python ./build_tools/github_actions/write_torch_version.py

- name: Configure AWS Credentials
if: always()
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/build_windows_pytorch_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ jobs:
PACKAGE_DIST_DIR: ${{ github.workspace }}\output\packages\dist
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
optional_build_prod_arguments: ""
outputs:
torch_version: ${{ steps.build-pytorch-wheels.outputs.torch_version }}
defaults:
run:
# Note: there are mixed uses of 'bash' (this default) and 'cmd' below
Expand Down Expand Up @@ -112,6 +114,7 @@ jobs:
--rocm-version ${{ inputs.rocm_version }}

- name: Build PyTorch Wheels
id: build-pytorch-wheels
# Using 'cmd' here is load bearing! There are configuration issues when
# run under 'bash': https://github.com/ROCm/TheRock/issues/827#issuecomment-3025858800
shell: cmd
Expand All @@ -126,6 +129,7 @@ jobs:
--clean ^
--output-dir ${{ env.PACKAGE_DIST_DIR }} ^
${{ env.optional_build_prod_arguments }}
python ./build_tools/github_actions/write_torch_version.py

- name: Configure AWS Credentials
if: always()
Expand Down
19 changes: 19 additions & 0 deletions build_tools/github_actions/write_torch_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3

"""Writes torch_version to GITHUB_OUTPUT."""

import os
import glob

from github_actions_utils import *


def main(argv: list[str]):
# Get the torch version from the first torch wheel in PACKAGE_DIST_DIR.
package_dist_dir = os.getenv("PACKAGE_DIST_DIR")
version = glob.glob("torch-*.whl", root_dir=package_dist_dir)[0].split("-")[1]
gha_set_output({"torch_version": version})


if __name__ == "__main__":
main(sys.argv[1:])
Loading