Skip to content

Commit

Permalink
Dockerfile torch fix (#987)
Browse files Browse the repository at this point in the history
* add torch to requirements.txt at build time to force version to stick

* fix xformers check

* better handling of xformers based on installed torch version

* fix for ci w/o torch
  • Loading branch information
winglian authored Dec 21, 2023
1 parent d25c34c commit 161bcb6
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
pytorch: 2.1.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps:
- name: Checkout
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
pytorch: 2.1.1
axolotl_extras:
runs-on: [self-hosted, gpu, docker]
steps:
Expand Down Expand Up @@ -80,7 +80,7 @@ jobs:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
pytorch: 2.1.1
axolotl_extras:
runs-on: [self-hosted, gpu, docker]
steps:
Expand Down
1 change: 0 additions & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
else \
Expand Down
15 changes: 9 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""setup.py for axolotl"""

from importlib.metadata import PackageNotFoundError, version

from setuptools import find_packages, setup


Expand All @@ -22,12 +24,13 @@ def parse_requirements():
# Handle standard packages
_install_requires.append(line)

# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
try:
torch_version = version("torch")
if torch_version.startswith("2.1.1"):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers==0.0.23")
except PackageNotFoundError:
pass

return _install_requires, _dependency_links

Expand Down

0 comments on commit 161bcb6

Please sign in to comment.