Changes from 52 commits (54 commits total)
ef95183  3rdparty tvm bump (LeiWang1999, Oct 22, 2025)
3c175e4  bump tvm into v0.22.0 (LeiWang1999, Oct 22, 2025)
951f2de  lint fix (LeiWang1999, Oct 22, 2025)
6d29c1e  rebase tvm (LeiWang1999, Oct 23, 2025)
21e7a0a  Update submodule tvm to latest commit 3085bc4 (LeiWang1999, Oct 23, 2025)
3877610  Refactor: Update configuration retrieval in CopyNode and adjust test … (LeiWang1999, Oct 23, 2025)
cd6daaf  lint fix (LeiWang1999, Oct 23, 2025)
995315e  test fix (LeiWang1999, Oct 23, 2025)
7ef5d01  add requirement (LeiWang1999, Oct 24, 2025)
fdf4669  atomic_fix (LeiWang1999, Oct 24, 2025)
9ecf41a  atomic_fix (LeiWang1999, Oct 24, 2025)
68b6ada  phaseout py39 (LeiWang1999, Oct 25, 2025)
dc12ebc  optimize (LeiWang1999, Oct 25, 2025)
751dbc7  optimize (LeiWang1999, Oct 25, 2025)
fde6a50  Merge branch 'main' of https://github.com/tile-ai/tilelang into tvm_r… (LeiWang1999, Oct 25, 2025)
38f7e49  lint fix (LeiWang1999, Oct 25, 2025)
46dfea1  do not clean cache (LeiWang1999, Oct 26, 2025)
33d6ad1  do not clean cache (LeiWang1999, Oct 26, 2025)
c768919  Merge branch 'main' of https://github.com/tile-ai/tilelang into tvm_r… (LeiWang1999, Oct 27, 2025)
f9a97c7  [Minor] Minor update for Python versions and dependencies (XuehaiPan, Oct 27, 2025)
c9d64fa  [Lint] fix lint for py39 (XuehaiPan, Oct 27, 2025)
89df129  Merge remote-tracking branch 'upstream/main' into tvm_rebase (XuehaiPan, Oct 27, 2025)
4abf1d0  [Lint] fix lint for ROCm (XuehaiPan, Oct 27, 2025)
452e40f  [Build][CI] Sync CI changes from upstream/sdist (XuehaiPan, Oct 27, 2025)
b9306ab  [Lint] fix lint for ROCm (XuehaiPan, Oct 27, 2025)
40a6138  [Build][CI] Update `repair-wheel-command` (XuehaiPan, Oct 27, 2025)
565fa97  Merge remote-tracking branch 'upstream/main' into tvm_rebase (XuehaiPan, Oct 27, 2025)
305c86a  [Minor] update abi3audit result format (XuehaiPan, Oct 27, 2025)
4d56db2  [Lint] fix lint for ROCm (XuehaiPan, Oct 27, 2025)
daf6c55  [BugFix] fix build (XuehaiPan, Oct 27, 2025)
66ac445  [Lint] fix lint for ROCm (XuehaiPan, Oct 27, 2025)
fbc250e  [BugFix] set rpath for libtvm and libtvm_runtime (XuehaiPan, Oct 27, 2025)
6bbf6aa  [Deps] pin apache-tvm-ffi version (XuehaiPan, Oct 27, 2025)
4aa24ea  [Build] set Python 3.9 Limited API for Cython target (XuehaiPan, Oct 27, 2025)
1b21e95  [Build] set Python 3.9 Limited API for Cython target (XuehaiPan, Oct 27, 2025)
1de65aa  [Deps] Restore Python 3.8 support (XuehaiPan, Oct 27, 2025)
6e8ad0d  Merge remote-tracking branch 'upstream/main' into tvm_rebase (XuehaiPan, Oct 27, 2025)
f4baf32  Merge remote-tracking branch 'upstream/main' into tvm_rebase (XuehaiPan, Oct 28, 2025)
ba13761  [Build] use `apache-tvm-ffi`'s `libtvm_ffi` (XuehaiPan, Oct 28, 2025)
5be4057  [BugFix] use `;` as delimiter for RPATH on macOS (XuehaiPan, Oct 28, 2025)
5e2ed4f  [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel` (XuehaiPan, Oct 28, 2025)
2d5faa8  [Build] support `sccache` if available (XuehaiPan, Oct 28, 2025)
9a48282  [Build] add CIBW import test (XuehaiPan, Oct 28, 2025)
cd9ab57  [Build][CI] enable ccache for CIBW on Linux (XuehaiPan, Oct 28, 2025)
1807298  [BugFix] set rpath for libtvm and libtvm_runtime (XuehaiPan, Oct 28, 2025)
aa4eb5d  Revert "[Build][CI] enable ccache for CIBW on Linux" (XuehaiPan, Oct 28, 2025)
024c1e3  [CI] fix perfbench bot (XuehaiPan, Oct 28, 2025)
f276a26  [BugFix] use Python 3.9 to build wheel (XuehaiPan, Oct 28, 2025)
8a23c08  [Minor] update perfbench bot envs (XuehaiPan, Oct 28, 2025)
8424e99  [BugFix] fix CIBW environment on Linux (XuehaiPan, Oct 28, 2025)
9ef647c  [CI] skip import test on CentOS 7 (XuehaiPan, Oct 28, 2025)
4bf2524  Merge remote-tracking branch 'upstream/main' into tvm_rebase (XuehaiPan, Oct 29, 2025)
c9e9191  [CI] use Python urllib to download file instead of Wget (XuehaiPan, Oct 29, 2025)
74ec78c  Merge remote-tracking branch 'upstream/main' into tvm_rebase (XuehaiPan, Oct 29, 2025)
2 changes: 1 addition & 1 deletion .clang-tidy
@@ -1,6 +1,6 @@
---
InheritParentConfig: true
ExtraArgs: ['-v']
ExtraArgs: []
FormatStyle: file
UseColor: true
WarningsAsErrors: '*'
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
@@ -22,10 +22,12 @@ env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
UV_INDEX_STRATEGY: "unsafe-best-match"
UV_HTTP_TIMEOUT: "600"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated
@@ -44,15 +46,15 @@ jobs:
submodules: recursive

- name: Setup Python 3.8
id: setup-py38
id: setup-pylowest
uses: actions/setup-python@v6
with:
python-version: "3.8" # use lowest supported version for linting
update-environment: false

- name: Check AST with Python 3.8
run: |
"${{ steps.setup-py38.outputs.python-path }}" -m compileall -q -f tilelang
"${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang

- name: Setup Python 3.12
uses: actions/setup-python@v6
13 changes: 5 additions & 8 deletions .github/workflows/dist.yml
@@ -108,14 +108,11 @@ jobs:
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: macos-latest, toolkit: "Metal" }
python-version:
- "3.8"
# TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8.
# - "3.9"
# - "3.10"
# - "3.11"
# - "3.12"
# - "3.13"
# - "3.14"
# Wheels are built with the Python 3.8 Limited API; they should work with all Python >= 3.8.
# Only build wheels against Python 3.8 Limited API to save CI resources.
# FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
# Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
- "3.9"
fail-fast: false
timeout-minutes: 120
runs-on: ${{ matrix.target.runner }}
18 changes: 17 additions & 1 deletion .github/workflows/pr-perfbench-bot.yml
@@ -12,6 +12,17 @@ concurrency:
group: "${{ github.workflow }}-${{ github.ref }}"
cancel-in-progress: true # always cancel in-progress

env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated

jobs:
perfbench:
name: Benchmark between PR and main
@@ -31,7 +42,12 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.9"
python-version: "3.12"
update-environment: true
cache: pip
cache-dependency-path: |
pyproject.toml
requirements*.txt

- name: Install merged version
run: |
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 5bf17a to 0f1eba
42 changes: 27 additions & 15 deletions CMakeLists.txt
@@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}")
# Warning comes from the tvm submodule
string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference")
endif()

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git")
@@ -36,9 +41,18 @@ endif()

find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
else()
find_program(SCCACHE_PROGRAM sccache)
if(SCCACHE_PROGRAM)
message(STATUS "Using sccache: ${SCCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
endif()
endif()

# Configs
@@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# webgpu doesn't have system dependency
src/target/codegen_webgpu.cc
# intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc
)
@@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper

# let libtilelang search tvm/tvm_runtime in the same dir
if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path")
else()
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN")
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
endif()

install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib)

# Copy tvm cython ext for wheels
# TODO: not necessary for editable builds
if(TVM_BUILD_FROM_SOURCE)
add_dependencies(tilelang tvm_cython)
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
endif()
install(
TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib
)
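
The RPATH entries above are what allow an installed wheel to locate its bundled `libtvm`, `libtvm_runtime`, and `libtvm_ffi` without `LD_LIBRARY_PATH`. A quick sanity check for a built wheel is sketched below in Python; the `tilelang/lib` layout follows the install rules above, but the library file names are platform-dependent assumptions (`.dylib` on macOS).

```python
# Sketch: confirm the bundled libraries load via RPATH alone.
import ctypes
import importlib.util
import pathlib

spec = importlib.util.find_spec("tilelang")
libdir = pathlib.Path(spec.origin).parent / "lib"

for name in ("libtvm_runtime.so", "libtvm.so", "libtilelang.so"):
    path = libdir / name
    if path.exists():
        ctypes.CDLL(str(path))  # raises OSError if RPATH resolution fails
        print(f"{name}: OK")
```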
11 changes: 10 additions & 1 deletion cmake/load_tvm.cmake
@@ -11,8 +11,17 @@ endif()

set(TVM_INCLUDES
${TVM_SOURCE}/include
${TVM_SOURCE}/ffi/include
${TVM_SOURCE}/src
${TVM_SOURCE}/3rdparty/dlpack/include
${TVM_SOURCE}/3rdparty/dmlc-core/include
)

if(EXISTS ${TVM_SOURCE}/ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include)
elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
endif()

if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
endif()
71 changes: 37 additions & 34 deletions examples/gemm/README.md
@@ -4,31 +4,34 @@ TileLang is a domain-specific language designed to simplify the process of writi

## Table of Contents

1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
- [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Verifying Correctness](#verifying-correctness)
- [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
- [References](#references)

---

## Getting Started

### Prerequisites

- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification)
- **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples)

### Installation
@@ -87,34 +90,34 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo

### Code Walkthrough

1. **Define the Kernel Launch Configuration:**
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.

2. **Shared Memory Allocation:**
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.

3. **Local Fragment Accumulation:**
```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
Partial results are stored in registers (or local memory) to reduce writes to global memory.

4. **Pipelined Loading and GEMM:**
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...)
T.gemm(...)
```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages), overlapping data transfer with computation.

5. **Copy Out the Results:**
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
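
To run the kernel end to end, a minimal compile-and-call sketch follows. It assumes the `matmul` factory defined above; `tilelang.compile` and its `out_idx` argument follow recent tilelang releases and may differ in older versions.

```python
import tilelang
import torch

func = matmul(1024, 1024, 1024, block_M=128, block_N=128, block_K=32)
kernel = tilelang.compile(func, out_idx=-1)  # treat the last buffer (C) as the output

a = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
c = kernel(a, b)  # launches the generated CUDA kernel
```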
@@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
```

**Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes the global-to-shared copy across all threads, potentially vectorizing load/store instructions (see the sketch below).
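
A sketch of how these three annotations typically sit together inside the kernel body is shown below. The `make_swizzle_layout` helper is a stand-in for the swizzle utility the full example imports (e.g., from bitblas), and the `panel_size` value is illustrative.

```python
T.annotate_layout({
    A_shared: make_swizzle_layout(A_shared),  # assumed helper name
    B_shared: make_swizzle_layout(B_shared),
})
T.use_swizzle(panel_size=10)  # rasterize blocks in panels for L2 reuse

for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
    # parallel copy: threads cooperatively load the tile, enabling vectorized accesses
    for i, j in T.Parallel(block_M, block_K):
        A_shared[i, j] = A[by * block_M + i, k * block_K + j]
```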

---

@@ -247,7 +250,7 @@ print("Results match!")
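
The elided check follows the usual pattern: run the TileLang kernel, compute a PyTorch reference, and compare. A sketch, assuming `kernel`, `a`, and `b` from the compile step earlier:

```python
c = kernel(a, b)
ref = (a.float() @ b.float()).half()  # fp32 reference, cast back to fp16
torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)
print("Results match!")
```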

## Fine-grained MMA Computations

For advanced users who require full control over warp-level matrix multiplication, TileLang lets you specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantized GEMM, which may require fine-grained layout transformations between the shared-memory and register stages) can benefit from explicit control over each MMA instruction, the data layout, and the synchronization points.

### Example Workflow

@@ -394,10 +397,10 @@ def tl_matmul(
]
```

1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads are active and that each thread is bound to a specific role (e.g., warp ID or lane ID), as sketched below.
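In a kernel this typically looks like the following; the exact `T.thread_binding` signature varies between releases, so treat it as an assumption:
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    # bind a flat thread index, then derive warp/lane roles from it
    tx = T.thread_binding(0, threads, thread="threadIdx.x")  # assumed signature
    warp_id = tx // 32  # which warp within the block
    lane_id = tx % 32   # lane within the warp
```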

2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
@@ -406,7 +409,7 @@ def tl_matmul(
```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
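The allocations for \(B\) and the accumulator follow the same pattern; the sketch below mirrors the `A_local` line above, with the `local_size_b` / `local_size_c` names assumed for illustration.
```python
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
```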

3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
for ki in T.serial(0, (block_K // micro_size_k)):
@@ -418,7 +421,7 @@ def tl_matmul(
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.

4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
\]
@@ -429,7 +432,7 @@ def tl_matmul(
Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel.

5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared)
```
@@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma

## References

- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
3 changes: 3 additions & 0 deletions format.sh
@@ -83,6 +83,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then
echo "Checking specified files: ${FILES[*]}..." >&2
fi

# Some systems set pip's default to --user, which breaks isolated virtualenvs.
export PIP_USER=0

# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit