Added ADC and Variable Input Voltage Range Support, and Modularized All memtorch.mn Modules (#30)
coreylammie committed Feb 5, 2021
1 parent 6496cb3 commit a0d82ce
Showing 23 changed files with 357 additions and 112 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "memtorch/submodules/pytorch-playground"]
path = memtorch/submodules/pytorch-playground
url = https://github.com/coreylammie/pytorch-playground
4 changes: 2 additions & 2 deletions .travis.yml
@@ -17,9 +17,9 @@ install:
- sudo apt-get install -y ninja-build
- python -m pip install --upgrade pip
- python -m pip install -U pytest
- python -m pip install numpy pandas torch torchvision matplotlib seaborn sklearn codecov pytest-cov travispls
- python -m pip install numpy pandas torch torchvision matplotlib seaborn sklearn codecov pytest-cov travispls ipython
- python setup.py install
script:
- travis-pls -m 1200 pytest -s --cov=memtorch
- travis-pls -m 5000 pytest -s --cov=memtorch
after_success:
- codecov
Binary file removed Animate.gif
Binary file not shown.
8 changes: 1 addition & 7 deletions README.md
@@ -14,12 +14,6 @@

MemTorch is a *Simulation Framework for Memristive Deep Learning Systems* that integrates directly with the well-known *PyTorch* Machine Learning (ML) library. It is presented in *MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems*, which has been released [here](https://arxiv.org/abs/2004.10971).

<p align="center">
<img src="Animate.gif" width=100%>
<br>
<em>Example usage.</em>
</p>

## MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems
> Corey Lammie, Wei Xiang, Bernabé Linares-Barranco, and Mostafa Rahimi Azghadi<br>
>
@@ -29,7 +23,7 @@ MemTorch is a *Simulation Framework for Memristive Deep Learning Systems* which
To install MemTorch from source:

```
git clone https://github.com/coreylammie/MemTorch
git clone --recursive https://github.com/coreylammie/MemTorch
cd MemTorch
python setup.py install
```
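
Because the build now pulls in the pytorch-playground submodule, a quick post-install sanity check can confirm that it was cloned and installed correctly. A minimal sketch (names taken from memtorch/bh/Quantize.py added in this commit; the import path assumes a standard installed package layout):

```
# Hypothetical post-install check: the import below only succeeds if the
# pytorch-playground submodule was cloned (--recursive) and installed.
from memtorch.bh.Quantize import quantize, quant_methods
print(quant_methods)  # ['linear', 'log', 'tanh']
```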
10 changes: 10 additions & 0 deletions codecov.yml
@@ -0,0 +1,10 @@
coverage:
status:
project:
default:
target: 90%
threshold: 0.5%
patch:
default:
target: 85%
threshold: 0.5%
1 change: 1 addition & 0 deletions memtorch/__init__.py
@@ -5,3 +5,4 @@
from memtorch.mn import *
from memtorch.utils import *
from memtorch.map import *
from memtorch.submodules import *
50 changes: 50 additions & 0 deletions memtorch/bh/Quantize.py
@@ -0,0 +1,50 @@
# Wrapper for the pytorch-playground quant.py script
import importlib
utee = importlib.import_module('.utee', 'memtorch.submodules.pytorch-playground')
import torch
import numpy as np
quant_methods = ['linear', 'log', 'tanh']

def quantize(input, bits, overflow_rate, quant_method='linear', min=None, max=None):
"""Method to quantize a tensor.
Parameters
----------
input : tensor
Input tensor.
bits : int
Bit width.
overflow_rate : float
Overflow rate threshold for linear quantization.
quant_method : str
Quantization method. Must be in ['linear', 'log', 'tanh'].
min : float
Minimum value to clip values to.
max : float
Maximum value to clip values to.
Returns
-------
tensor
Quantized tensor.
"""
assert type(bits) == int and bits > 0, 'bits must be an integer > 0.'
assert overflow_rate >= 0 and overflow_rate <= 1, 'overflow_rate value invalid.'
assert quant_method in quant_methods, 'quant_method is not valid.'
if min is not None:
input = input.clip(min=min)

if max is not None:
input = input.clip(max=max)

if quant_method == 'linear':
sf = bits - 1 - utee.compute_integral_part(input, overflow_rate)
return utee.linear_quantize(input, sf, bits)
elif quant_method == 'log':
log_abs_input = torch.log(torch.abs(input))
log_abs_input[log_abs_input == float('-inf')] = 1e-12
sf = bits - 1 - utee.compute_integral_part(log_abs_input, overflow_rate)
return utee.log_linear_quantize(input, sf, bits)
elif quant_method == 'tanh':
return utee.tanh_quantize(input, bits)
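
For reference, a minimal usage sketch of the wrapper above (the input tensor and the 8-bit linear settings are illustrative assumptions, not part of the commit):

```
import torch
from memtorch.bh.Quantize import quantize

x = torch.randn(4, 4)  # illustrative input tensor
# 8-bit linear quantization with a zero overflow-rate threshold and no clipping.
x_q = quantize(x, bits=8, overflow_rate=0., quant_method='linear')
print(x_q)
```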
1 change: 1 addition & 0 deletions memtorch/bh/__init__.py
@@ -1,3 +1,4 @@
from .memristor import *
from .crossbar import *
from .StochasticParameter import *
from .Quantize import *
104 changes: 58 additions & 46 deletions memtorch/bh/crossbar/Crossbar.py
@@ -76,6 +76,7 @@ def __init__(self, memristor_model, memristor_model_params, shape, tile_shape=No

self.g_np = np.vectorize(lambda x: x.g)
self.update(from_devices=False)
self.max_abs_conductance = torch.abs(self.conductance_matrix).flatten().max()

def update(self, from_devices=True, parallelize=False):
"""Method to update either the layers conductance_matrix or each devices conductance state.
Expand All @@ -89,6 +90,7 @@ def update(self, from_devices=True, parallelize=False):
"""
if from_devices:
self.conductance_matrix = torch.tensor(self.g_np(self.devices)).to(self.device)
self.max_abs_conductance = torch.abs(self.conductance_matrix).flatten().max()
else:
if parallelize:
def write_conductance(device, conductance):
@@ -137,6 +139,7 @@ def write_conductance_matrix(self, conductance_matrix, transistor=True, programm
conductance_matrix = torch.max(torch.min(conductance_matrix.to(self.device), max), min)
if transistor:
self.conductance_matrix = conductance_matrix
self.max_abs_conductance = torch.abs(self.conductance_matrix).flatten().max()
self.update(from_devices=False)
else:
assert programming_routine is not None, 'programming_routine must be defined if transistor is False.'
@@ -150,6 +153,8 @@ def write_conductance_matrix(self, conductance_matrix, transistor=True, programm
for j in range(0, self.columns):
self.devices = programming_routine(self, (i, j), conductance_matrix[i][j], **programming_routine_params)

self.update(from_devices=True)

def init_crossbar(weights, memristor_model, memristor_model_params, transistor, mapping_routine, programming_routine, programming_routine_params={}, p_l=None, scheme=Scheme.DoubleColumn, tile_shape=(128, 128)):
"""Method to initialise and construct memristive crossbars.
@@ -247,29 +252,49 @@ def out(crossbars, operation, idx=0, **kwargs):

return crossbars, out

def simulate_matmul(input, devices, nl=True, tiles_map=None, crossbar_shape=None):
def simulate_matmul(input, crossbar, nl=True, tiles_map=None, crossbar_shape=None, max_input_voltage=None, ADC_resolution=None, ADC_overflow_rate=0., quant_method=None):
"""Method to simulate non-linear IV device characterisitcs for a 2-D crossbar architecture given scaled inputs.
Parameters
----------
input : tensor
Scaled input tensor.
devices : numpy.ndarray
Devices to simulate.
crossbar : memtorch.bh.Crossbar
Crossbar containing devices to simulate.
nl : bool
Use lookup tables rather than simulating each device (True).
tiles_map: torch.tensor
Tiles map for devices if tile_shape is not None.
crossbar_shape : (int, int)
Crossbar shape if tile_shape is not None.
max_input_voltage : float
Maximum input voltage used to encode inputs. If None, inputs are unbounded.
ADC_resolution : int
ADC resolution (bit width). If None, quantization noise is not accounted for.
ADC_overflow_rate : float
Overflow rate threshold for linear quantization (if ADC_resolution is not None).
quant_method : str
Quantization method. Must be in ['linear', 'log', 'tanh'], or None.
Returns
-------
torch.tensor
Output tensor.
"""
devices = crossbar.devices
if max_input_voltage is not None:
output_max = crossbar.max_abs_conductance * max_input_voltage
else:
output_max = float('inf')

del crossbar
assert len(devices.shape) == 2 or len(devices.shape) == 3, 'Invalid devices shape.'
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
if quant_method is not None:
assert ADC_resolution is not None and type(ADC_resolution) == int and ADC_resolution > 0, 'ADC resolution is invalid.'
assert quant_method in memtorch.bh.Quantize.quant_methods, 'quant_method is not valid.'
assert ADC_overflow_rate is not None, 'ADC_overflow_rate must be specified if quant_method is not None.'

input_rows, input_columns = input.shape
if len(devices.shape) == 2:
devices_rows, devices_columns = devices.shape
@@ -284,62 +309,49 @@ def simulate_matmul(input, devices, nl=True, tiles_map=None, crossbar_shape=None
for j in range(devices_columns):
for k in range(input_columns):
mat_res_[i][j] += devices[k][j].simulate(torch.Tensor([input[i][k]]).cpu(), return_current=True).item()

mat_res_ = torch.clamp(mat_res_, min=-output_max, max=output_max)
if quant_method is not None:
mat_res_ = memtorch.bh.Quantize.quantize(mat_res_, bits=ADC_resolution, overflow_rate=ADC_overflow_rate, quant_method=quant_method)
else:
assert tiles_map is not None and crossbar_shape is not None, 'tiles_map and crossbar_shape must both be defined.'
tile_shape = devices.shape[-2:]
input_tiles, input_tiles_map = gen_tiles(input, tile_shape, input=True)
mat_res_ = torch.zeros((input.shape[0], crossbar_shape[1])).to(device)
if nl:
def tile_simulate_matmul_row(input_row_tiles, input_tiles_map, devices, tiles_map, crossbar_shape):
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
tile_shape = devices.shape[-2:]
partial_sum = torch.zeros((tiles_map.shape[1], tile_shape[1])).to(device)
for j in range(tiles_map.shape[1]):
for i in range(tiles_map.shape[0]):
tile_a = input_row_tiles[int(input_tiles_map[i])]
if len(tile_a.shape) == 1:
tile_a = tile_a.unsqueeze(0)

tile_b = devices[int(tiles_map[i][j])]
mat_res = torch.zeros((tile_a.shape[0], tile_b.shape[1])).to(device)
for ii in range(tile_a.shape[0]):
for jj in range(tile_b.shape[1]):
for kk in range(tile_b.shape[0]):
def tile_simulate_matmul_row(input_row_tiles, input_tiles_map, devices, tiles_map, crossbar_shape, nl, ADC_resolution, ADC_overflow_rate, quant_method):
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
tile_shape = devices.shape[-2:]
partial_sum = torch.zeros((tiles_map.shape[1], tile_shape[1])).to(device)
for j in range(tiles_map.shape[1]):
for i in range(tiles_map.shape[0]):
tile_a = input_row_tiles[int(input_tiles_map[i])]
if len(tile_a.shape) == 1:
tile_a = tile_a.unsqueeze(0)

tile_b = devices[int(tiles_map[i][j])]
mat_res = torch.zeros((tile_a.shape[0], tile_b.shape[1])).to(device)
for ii in range(tile_a.shape[0]):
for jj in range(tile_b.shape[1]):
for kk in range(tile_b.shape[0]):
if nl:
mat_res[ii][jj] += tile_a[ii][kk].item() * tile_b[kk][jj].g

partial_sum[j] += mat_res.squeeze()

output_act = partial_sum.flatten()
output_act = output_act[:crossbar_shape[1]]
return output_act
else:
def tile_simulate_matmul_row(input_row_tiles, input_tiles_map, devices, tiles_map, crossbar_shape):
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
tile_shape = devices.shape[-2:]
partial_sum = torch.zeros((tiles_map.shape[1], tile_shape[1])).to(device)
for j in range(tiles_map.shape[1]):
for i in range(tiles_map.shape[0]):
tile_a = input_row_tiles[int(input_tiles_map[i])]
if len(tile_a.shape) == 1:
tile_a = tile_a.unsqueeze(0)

tile_b = devices[int(tiles_map[i][j])]
mat_res = torch.zeros((tile_a.shape[0], tile_b.shape[1])).to(device)
for ii in range(tile_a.shape[0]):
for jj in range(tile_b.shape[1]):
for kk in range(tile_b.shape[0]):
else:
mat_res[ii][jj] += tile_b[kk][jj].simulate(torch.Tensor([tile_a[ii][kk]]).cpu(), return_current=True).item()

mat_res = torch.clamp(mat_res, min=-output_max, max=output_max)
if quant_method is not None:
partial_sum[j] += memtorch.bh.Quantize.quantize(mat_res.squeeze(), bits=ADC_resolution, overflow_rate=ADC_overflow_rate, quant_method=quant_method)
else:
partial_sum[j] += mat_res.squeeze()

output_act = partial_sum.flatten()
output_act = output_act[:crossbar_shape[1]]
return output_act
output_act = partial_sum.flatten()
output_act = output_act[:crossbar_shape[1]]
return output_act

if input_tiles.shape[-2] > 1:
for row_idx in range(input_tiles.shape[-2]):
mat_res_[row_idx] = tile_simulate_matmul_row(input_tiles[:, row_idx, :], input_tiles_map, devices, tiles_map, crossbar_shape)
mat_res_[row_idx] = tile_simulate_matmul_row(input_tiles[:, row_idx, :], input_tiles_map, devices, tiles_map, crossbar_shape, nl, ADC_resolution, ADC_overflow_rate, quant_method)
else:
mat_res_ = tile_simulate_matmul_row(input_tiles, input_tiles_map, devices, tiles_map, crossbar_shape)
mat_res_ = tile_simulate_matmul_row(input_tiles, input_tiles_map, devices, tiles_map, crossbar_shape, nl, ADC_resolution, ADC_overflow_rate, quant_method)

return mat_res_
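
The ADC and input-voltage-range model added to simulate_matmul above reduces to two steps: clamp each analog partial sum to ±(max |G| · V_max), then quantize it to ADC_resolution bits. A standalone sketch of just that step (the tensor values, conductance bound, read voltage, and 8-bit linear ADC settings below are illustrative assumptions):

```
import torch
import memtorch

# Illustrative values: an analog partial-sum tensor, the maximum absolute device
# conductance of the crossbar, and the maximum input (read) voltage.
mat_res_ = torch.randn(4, 4) * 1e-3
max_abs_conductance = 5e-4
max_input_voltage = 1.0

# Bound the analog output, as simulate_matmul does when max_input_voltage is set.
output_max = max_abs_conductance * max_input_voltage
mat_res_ = torch.clamp(mat_res_, min=-output_max, max=output_max)

# Model the ADC with 8-bit linear quantization (ADC_resolution=8, ADC_overflow_rate=0.).
mat_res_ = memtorch.bh.Quantize.quantize(mat_res_, bits=8, overflow_rate=0., quant_method='linear')
```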
32 changes: 26 additions & 6 deletions memtorch/bh/crossbar/Tile.py
@@ -119,7 +119,8 @@ def gen_tiles(tensor, tile_shape, input=False):
tiles = torch.tensor([np.array(tile.array.cpu()) for tile in tiles])
return tiles, tiles_map

def tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape):
def tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape,
ADC_resolution=None, ADC_overflow_rate=0., quant_method=None):
""" Method to perform 2D matrix multiplication, given two sets of tiles.
Parameters
@@ -136,32 +137,51 @@ def tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_ti
Tiles map for matrix B.
mat_b_shape : (int, int)
Shape of matrix B.
ADC_resolution : int
ADC resolution (bit width). If None, quantization noise is not accounted for.
ADC_overflow_rate : float
Overflow rate threshold for linear quantization (if ADC_resolution is not None).
quant_method : str
Quantization method. Must be in ['linear', 'log', 'tanh'], or None.
Returns
-------
torch.tensor
Output tensor.
"""
def tile_matmul_row(mat_a_row_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape):
def tile_matmul_row(mat_a_row_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape,
ADC_resolution=None, ADC_overflow_rate=0., quant_method=None):
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
if quant_method is not None:
assert ADC_resolution is not None and type(ADC_resolution) == int and ADC_resolution > 0, 'ADC resolution is invalid.'
assert quant_method in memtorch.bh.Quantize.quant_methods, 'quant_method is not valid.'
assert ADC_overflow_rate is not None, 'ADC_overflow_rate must be specified if quant_method is not None.'

tile_shape = mat_b_tiles.shape[-2:]
partial_sum = torch.zeros((mat_b_tiles_map.shape[1], tile_shape[1])).to(device)
for j in range(mat_b_tiles_map.shape[1]):
for i in range(mat_b_tiles_map.shape[0]):
tile_a = mat_a_row_tiles[int(mat_a_tiles_map[i])]
tile_b = mat_b_tiles[int(mat_b_tiles_map[i][j])]
partial_sum[j] += torch.matmul(tile_a.to(device), tile_b.to(device)).squeeze()
if quant_method is not None:
partial_sum[j] += memtorch.bh.Quantize.quantize(torch.matmul(tile_a.to(device), tile_b.to(device)).squeeze(), \
bits=ADC_resolution, overflow_rate=ADC_overflow_rate, quant_method=quant_method)
else:
partial_sum[j] += torch.matmul(tile_a.to(device), tile_b.to(device)).squeeze()

output_act = partial_sum.flatten()
output_act = output_act[:mat_b_shape[1]]
return output_act

assert mat_a_tiles.shape[-1] == mat_b_tiles.shape[-2] and len(mat_a_tiles.shape) == 3 and len(mat_b_tiles.shape) == 3 and mat_a_tiles.shape[-2] != 0, 'Incompatible tile shapes used.'
assert mat_a_tiles.shape[-1] == mat_b_tiles.shape[-2] and len(mat_a_tiles.shape) == 3 \
and len(mat_b_tiles.shape) == 3 and mat_a_tiles.shape[-2] != 0, 'Incompatible tile shapes used.'
result = torch.zeros((mat_a_shape[0], mat_b_shape[1]))
if mat_a_tiles.shape[-2] > 1:
for row_idx in range(mat_a_tiles.shape[-2]):
result[row_idx] = tile_matmul_row(mat_a_tiles[:, row_idx, :], mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape)
result[row_idx] = tile_matmul_row(mat_a_tiles[:, row_idx, :], mat_a_tiles_map, mat_a_shape, mat_b_tiles, \
mat_b_tiles_map, mat_b_shape, ADC_resolution, ADC_overflow_rate, quant_method)
else:
result = tile_matmul_row(mat_a_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape)
result = tile_matmul_row(mat_a_tiles, mat_a_tiles_map, mat_a_shape, mat_b_tiles, mat_b_tiles_map, mat_b_shape, \
ADC_resolution, ADC_overflow_rate, quant_method)

return result
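
For reference, a hypothetical end-to-end call of the updated tile_matmul (the matrix shapes, 128 × 128 tile size, and 8-bit linear ADC settings are assumptions for illustration; the import path follows memtorch/bh/crossbar/Tile.py):

```
import torch
from memtorch.bh.crossbar.Tile import gen_tiles, tile_matmul

# Assumed example: a (2, 256) input matrix and a (256, 64) weight matrix,
# both decomposed into 128 x 128 tiles.
mat_a = torch.randn(2, 256)
mat_b = torch.randn(256, 64)
mat_a_tiles, mat_a_tiles_map = gen_tiles(mat_a, (128, 128), input=True)
mat_b_tiles, mat_b_tiles_map = gen_tiles(mat_b, (128, 128), input=False)

# Tiled matrix product with an 8-bit linear ADC model applied to each partial sum.
result = tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a.shape,
                     mat_b_tiles, mat_b_tiles_map, mat_b.shape,
                     ADC_resolution=8, ADC_overflow_rate=0., quant_method='linear')
print(result.shape)  # expected: torch.Size([2, 64])
```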
