Skip to content

Fix pytorch memory curve estimation for rmm backed allocator #347

Fix pytorch memory curve estimation for rmm backed allocator

Fix pytorch memory curve estimation for rmm backed allocator #347

Workflow file for this run

# CI workflow for the crossfit array backends. It currently runs the PyTorch
# test suite once per torch release line listed in the job matrix.
name: "cf.backends"

# Triggers: manual dispatch, pushes to main or to tags matching v*, and pull
# requests that target main.
on:
  workflow_dispatch:
  push:
    branches: [main]
    tags:
      - "v*"
  pull_request:
    branches: [main]

jobs:
  pytorch:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        # Quoted so a version is always a string and never read as a float
        # (an unquoted 3.10 would become 3.1).
        python-version: ["3.8"]
        os: [ubuntu-latest]
        # PEP 440 compatible-release specifiers; each one is appended to
        # "torch" in the install step below.
        torch-version: ["~=1.11.0", "~=1.12.0", "~=1.13.0"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          # Files whose contents key the pip cache.
          cache-dependency-path: |
            **/setup.cfg
            requirements/*.txt
      - name: Install dependencies
        # The matrix-selected torch is installed first, then the project with
        # its pytorch-dev extra. The extra is quoted so the shell cannot treat
        # "[...]" as a glob pattern.
        run: |
          python -m pip install "torch${{ matrix.torch-version }}"
          python -m pip install ".[pytorch-dev]"
      - name: Build
        run: |
          python setup.py develop
      - name: Run Pytorch tests
        # Selects tests marked "pytorch" and skips any marked singlegpu or
        # multigpu; coverage is measured for crossfit/array only.
        run: |
          pytest --cov="crossfit/array" -m "pytorch and not (singlegpu or multigpu)"
# Disabled JAX job, kept for reference. To re-enable it, remove the leading
# "# " from every line below so that `jax:` nests under `jobs:`.
#   jax:
#     runs-on: ${{ matrix.os }}
#     strategy:
#       matrix:
#         python-version: [3.8]
#         os: [ubuntu-latest]
#         jax-version: ["~=0.3.0", "~=0.4.0"]
#     steps:
#       - uses: actions/checkout@v3
#       - name: Set up Python ${{ matrix.python-version }}
#         uses: actions/setup-python@v4
#         with:
#           python-version: ${{ matrix.python-version }}
#           cache: 'pip'
#           cache-dependency-path: |
#             **/setup.cfg
#             requirements/*.txt
#       - name: Install dependencies
#         run: |
#           python -m pip install "jax[cpu]${{ matrix.jax-version }}"
#           python -m pip install .[jax-dev]
#       - name: Build
#         run: |
#           python setup.py develop
#       - name: Run Jax tests
#         run: |
#           pytest --cov="crossfit/array" -m "jax"