Skip to content

Commit

Permalink
Added C++ and CUDA bindings to tile_matmul for 1.1.2 Release (#66)
Browse files Browse the repository at this point in the history
* Added C++ and CUDA bindings for `memtorch.bh.crossbar.Tile.tile_matmul`.
* Added `Eigen` integration with C++ and CUDA bindings.
* Modularized C++ and CUDA `quantize` bindings.
  • Loading branch information
coreylammie committed Jul 8, 2021
1 parent aa54b63 commit 055e036
Show file tree
Hide file tree
Showing 42 changed files with 1,155 additions and 392 deletions.
1 change: 1 addition & 0 deletions .clang-format-ignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
memtorch/cu/*
12 changes: 9 additions & 3 deletions .github/workflows/build_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v2
with:
submodules: recursive
- name: Create release
id: create_release
uses: actions/create-release@v1
Expand All @@ -32,6 +34,8 @@ jobs:
os: [windows-2019, macOS-10.15, ubuntu-20.04]
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- uses: actions/setup-python@v2
with:
python-version: 3.9
Expand All @@ -41,9 +45,9 @@ jobs:
- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
CIBW_BEFORE_BUILD_WINDOWS: pip3 install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed
CIBW_BEFORE_BUILD_MACOS: pip3 install torch==1.8.1 -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed
CIBW_BEFORE_BUILD_LINUX: pip3 install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed
CIBW_BEFORE_BUILD_WINDOWS: pip3 install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed
CIBW_BEFORE_BUILD_MACOS: pip3 install torch==1.9.0 -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed
CIBW_BEFORE_BUILD_LINUX: pip3 install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed
CIBW_REPAIR_WHEEL_COMMAND: ""
CIBW_BUILD: cp37-* cp38-* cp39-*
CIBW_SKIP: "*-manylinux_i686 *-win32"
Expand All @@ -66,6 +70,8 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- uses: actions/setup-python@v2
name: Install Python
with:
Expand Down
6 changes: 5 additions & 1 deletion .github/workflows/push_pull.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
name: CI
on: [push, pull_request]
on:
push:
tags-ignore:
- "v*"
pull_request:
jobs:
linter:
name: Validate code formatting
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ MemTorch_cpu.egg-info/
memtorch/examples/reproduce/*.csv
tmp/
**/.pytest_cache/
.vscode/
.eggs/
.vscode/
tmp.py
9 changes: 3 additions & 6 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[submodule "memtorch/submodules/pytorch-playground"]
path = memtorch/submodules/pytorch-playground
url = https://github.com/coreylammie/pytorch-playground
[submodule "memtorch/submodules/memtorch/submodules/pytorch-playground"]
path = memtorch/submodules/memtorch/submodules/pytorch-playground
url = https://github.com/coreylammie/pytorch-playground
[submodule "memtorch/submodules/eigen"]
path = memtorch/submodules/eigen
url = https://gitlab.com/libeigen/eigen.git
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1
rev: 21.6b0
hooks:
- id: black
language_version: python3
- repo: https://github.com/timothycrosley/isort
rev: 5.8.0
rev: 5.9.1
hooks:
- id: isort
- repo: https://github.com/pocc/pre-commit-hooks
rev: python
rev: v1.1.1
hooks:
- id: clang-format
24 changes: 23 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
- Transitioned from TravisCI to GitHub Actions.
## Added

1. C++ and CUDA bindings for `memtorch.bh.crossbar.Tile.tile_matmul`.

Using an NVIDIA GeForce GTX 1080, a tile shape of (25, 25), and two tensors of size (500, 500), the runtime of `tile_matmul` without quantization support is reduced by 2.45x and 5.48x, for CPU-bound and GPU-bound operation, respectively. With an ADC resolution of 4 bits and an overflow rate of 0.0, the runtime of `tile_matmul` with quantization support is reduced by 2.30x and 105.27x, for CPU-bound and GPU-bound operation, respectively.

| Implementation | Runtime Without Quantization Support (s) | Runtime With Quantization Support (s) |
| ---------------------- | ---------------------------------------- | ------------------------------------- |
| Pure Python (Previous) | 6.917784 | 27.099764 |
| C++ (CPU-bound) | 2.822265 | 11.736974 |
| CUDA (GPU-bound) | 1.262861 | 0.2574267 |

3. `Eigen` integration with C++ and CUDA bindings.
4. Additional unit tests.

## Enhanced

1. Modularized C++ and CUDA `quantize` bindings.
2. Enhanced functionality of `naive_progam` and added additional input arguments to dictate logic for stuck devices.

## Fixed

1. Removed debugging code from `naive_progam`.
5 changes: 4 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
include memtorch/cu/quantize/gpu.cuh
graft memtorch/submodules/eigen
include memtorch/cpp/*.h
include memtorch/cu/*.h
include memtorch/cu/*.cuh
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
author = "Corey Lammie"

# The full version, including alpha/beta/rc tags
release = "1.1.1"
release = "1.1.2"
autodoc_inherit_docstrings = False

# -- General configuration ---------------------------------------------------
Expand Down
95 changes: 56 additions & 39 deletions memtorch/bh/Quantize.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,74 @@
# Wrapper for the pytorch-playground quant.py script
import importlib
import copy

utee = importlib.import_module(".utee", "memtorch.submodules.pytorch-playground")
import numpy as np
import torch

quant_methods = ["linear", "log", "tanh"]
import memtorch
import memtorch_bindings

quant_methods = ["linear", "log"]

def quantize(input, bits, overflow_rate, quant_method="linear", min=None, max=None):

def quantize(
tensor,
quant,
overflow_rate=0.0,
quant_method=None,
min=float("nan"),
max=float("nan"),
override_original=False,
):
"""Method to quantize a tensor.
Parameters
----------
input : tensor
tensor : tensor
Input tensor.
bits : int
Bit width.
overflow_rate : float
Overflow rate threshold for linear quanitzation.
quant_method : str
Quantization method. Must be in ['linear', 'log', 'tanh'].
min : float
Minimum value to clip values to.
max : float
Maximum value to clip values to.
quant : int
Bit width (if quant_method is not None) or the number of discrete quantization levels (if quant_method is None).
overflow_rate : float, optional
Overflow rate threshold for linear quantization.
quant_method : str, optional
Quantization method. Must be in quant_methods.
min : float or tensor, optional
Minimum value(s) to clip numbers to.
max : float or tensor, optional
Maximum value(s) to clip numbers to.
override_original : bool, optional
Whether to override the original tensor (True) or not (False).
Returns
-------
tensor
Quantized tensor.
"""
assert type(bits) == int and bits > 0, "bits must be an integer > 0."
assert overflow_rate >= 0 and overflow_rate <= 1, "overflow_rate value invalid."
assert quant_method in quant_methods, "quant_method is not valid."
pass
if min is not None:
input = input.clip(min=min)

if max is not None:
input = input.clip(max=max)

if torch.unique(input).numel() == 1:
return input

if quant_method == "linear":
sf = bits - 1 - utee.compute_integral_part(input, overflow_rate)
return utee.linear_quantize(input, sf, bits)
elif quant_method == "log":
log_abs_input = torch.log(torch.abs(input))
log_abs_input[log_abs_input == float("-inf")] = 1e-12
sf = bits - 1 - utee.compute_integral_part(log_abs_input, overflow_rate)
return utee.log_linear_quantize(input, sf, bits)
elif quant_method == "tanh":
return utee.tanh_quantize(input, bits)
device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
assert (
overflow_rate >= 0 and overflow_rate <= 1
), "overflow_rate must be >= 0 and <= 1."
assert (
type(quant) == int and quant > 0
), "The bit width or number of discrete quantization levels must be a positive integer."
if type(min) == int:
min = float(min)
if type(max) == int:
max = float(max)
if not override_original:
tensor = copy.deepcopy(tensor)
if quant_method is not None:
assert quant_method in quant_methods, "quant_method is invalid."
tensor = tensor.cpu()
memtorch_bindings.quantize(
tensor,
bits=quant,
overflow_rate=overflow_rate,
quant_method=quant_methods.index(quant_method),
min=min,
max=max,
)
else:
tensor = tensor.cpu()
memtorch_bindings.quantize(tensor, n_quant_levels=quant, min=min, max=max)

return tensor.to(device)
9 changes: 4 additions & 5 deletions memtorch/bh/crossbar/Crossbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import torch.nn as nn

import memtorch
from memtorch.utils import pad_tensor

from .Tile import gen_tiles, tile_matmul
from .Tile import gen_tiles


@unique
Expand Down Expand Up @@ -443,7 +442,7 @@ def simulate_matmul(
ADC_overflow_rate : float
Overflow rate threshold for linear quanitzation (if ADC_resolution is not None).
quant_method:
Quantization method. Must be in ['linear', 'log', 'log_minmax', 'minmax', 'tanh'], or None.
Quantization method. Must be in memtorch.bh.Quantize.quant_methods.
Returns
-------
Expand Down Expand Up @@ -497,7 +496,7 @@ def simulate_matmul(
if quant_method is not None:
mat_res_ = memtorch.bh.Quantize.quantize(
mat_res_,
bits=ADC_resolution,
quant=ADC_resolution,
overflow_rate=ADC_overflow_rate,
quant_method=quant_method,
)
Expand Down Expand Up @@ -552,7 +551,7 @@ def tile_simulate_matmul_row(
if quant_method is not None:
partial_sum[j] += memtorch.bh.Quantize.quantize(
mat_res.squeeze(),
bits=ADC_resolution,
quant=ADC_resolution,
overflow_rate=ADC_overflow_rate,
quant_method=quant_method,
)
Expand Down
Loading

0 comments on commit 055e036

Please sign in to comment.