diff --git a/.env.example b/.env.example index ae80d08..7197a89 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,16 @@ +# Build configuration +TORCH_CUDA_ARCH_LIST=80;90 +CUDA_HOME=/usr/local/cuda +CUDA_CACHE_PATH=${HOME}/.cache/stream_attn/cuda +TRITON_CACHE_DIR=${HOME}/.cache/stream_attn/triton + +# StreamAttention environment configuration example +# Build configuration +TORCH_CUDA_ARCH_LIST=80;90 +CUDA_HOME=/usr/local/cuda +CUDA_CACHE_PATH=${HOME}/.cache/stream_attn/cuda +TRITON_CACHE_DIR=${HOME}/.cache/stream_attn/triton + # StreamAttention environment configuration example # Copy to .env and edit as needed @@ -18,4 +31,4 @@ STREAM_ATTENTION_RING_OVERLAP_SIZE=256 # Star Attention STREAM_ATTENTION_STAR_BLOCK_SIZE=2048 STREAM_ATTENTION_STAR_ANCHOR_SIZE=256 -STREAM_ATTENTION_STAR_NUM_HOSTS=1 \ No newline at end of file +STREAM_ATTENTION_STAR_NUM_HOSTS=1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..99d7073 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,28 @@ +name: Build Wheels + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + sm: [80, 90] + env: + TORCH_CUDA_ARCH_LIST: ${{ matrix.sm }} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install build tooling + run: | + python -m pip install --upgrade pip + pip install build + - name: Build wheel + run: | + python -m build --wheel diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..afee971 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = [ + "setuptools>=61", + "wheel", +] +build-backend = "setuptools.build_meta" + +[project] +name = "stream-attention" +version = "1.0.0" +description = "Production-ready multi-GPU FlashAttention implementation with support for extremely long contexts" +authors = [{name = "StreamAttention Team", email = "streamattention@example.com"}] +readme = "README.md" +license = {file = "LICENSE"} +requires-python = ">=3.8" +dependencies = [ + "torch==2.1.0", + "triton==2.1.0", + "pyyaml>=6.0", +] + +[tool.stream_attention] +cuda_version = "12.1" +triton_version = "2.1.0" diff --git a/stream_attention/benchmarks/accuracy_test.py b/stream_attention/benchmarks/accuracy_test.py index 5eb3522..92b774b 100644 --- a/stream_attention/benchmarks/accuracy_test.py +++ b/stream_attention/benchmarks/accuracy_test.py @@ -59,3 +59,4 @@ def main(): if __name__ == "__main__": main() +