From 0dfdbe16e1ac8bb3d4265f35ce28ce5a9f0ab262 Mon Sep 17 00:00:00 2001 From: Corey Lammie Date: Tue, 21 Sep 2021 20:47:44 +1000 Subject: [PATCH] Added Support for Modeling Source and Line Resistances for 1.1.4 Release (#98) * Added support for modeling source and line resistances for passive crossbars/tiles. * Added C++ and CUDA bindings for modeling source and line resistances for passive crossbars/tiles*. * Added a new MemTorch logo to `README.md`. * Added the `set_cuda_malloc_heap_size` routine to patched `torch.mn` modules. * Added unit tests for source and line resistance modeling. * Updated ReadTheDocs documentation. * Transitioned from Gitter to GitHub Discussions for general discussion. ***Note** It is strongly suggested to set `cuda_malloc_heap_size` using `m.set_cuda_malloc_heap_size` manually when simulating source and line resistances using CUDA bindings. --- .gitignore | 3 +- CHANGELOG.md | 29 +- README.md | 4 +- docs/conf.py | 2 +- docs/index.rst | 2 +- logo.svg | 112 +++++++ memtorch/bh/crossbar/Crossbar.py | 5 +- memtorch/bh/crossbar/Passive.py | 253 ++++++++++++++++ memtorch/bh/crossbar/Tile.py | 231 +++++++++++--- memtorch/bh/crossbar/__init__.py | 1 + .../empirical_metal_oxide_RRAM.py | 2 +- memtorch/cpp/bindings.cpp | 7 +- memtorch/cpp/inference.cpp | 85 +++++- memtorch/cpp/solve_passive.cpp | 272 +++++++++++++++++ memtorch/cpp/solve_passive.h | 11 + memtorch/cpp/solve_sparse_linear.cpp | 24 ++ memtorch/cpp/solve_sparse_linear.h | 7 + memtorch/cpp/tile_matmul.cpp | 127 +++++++- memtorch/cpp/tile_matmul.h | 10 + memtorch/cu/bindings.cpp | 5 +- memtorch/cu/inference.cpp | 92 +++++- memtorch/cu/solve_passive.cpp | 24 ++ memtorch/cu/solve_passive.cuh | 229 ++++++++++++++ memtorch/cu/solve_passive.h | 1 + memtorch/cu/solve_passive_kernels.cu | 284 ++++++++++++++++++ memtorch/cu/solve_passive_kernels.cuh | 3 + memtorch/cu/solve_sparse_linear.cpp | 24 ++ memtorch/cu/solve_sparse_linear.h | 6 + memtorch/cu/tile_matmul.cpp | 63 +++- memtorch/cu/tile_matmul_kernels.cu | 244 +++++++++++++-- memtorch/cu/tile_matmul_kernels.cuh | 3 +- memtorch/cu/utils.cuh | 7 +- memtorch/mn/Conv1d.py | 61 +++- memtorch/mn/Conv2d.py | 61 +++- memtorch/mn/Conv3d.py | 61 +++- memtorch/mn/Linear.py | 53 +++- memtorch/mn/Module.py | 15 + memtorch/version.py | 2 +- setup.py | 26 +- tests/test_cpp_extensions.py | 47 --- tests/test_networks.py | 14 +- 41 files changed, 2315 insertions(+), 197 deletions(-) create mode 100644 logo.svg create mode 100644 memtorch/bh/crossbar/Passive.py create mode 100644 memtorch/cpp/solve_passive.cpp create mode 100644 memtorch/cpp/solve_passive.h create mode 100644 memtorch/cpp/solve_sparse_linear.cpp create mode 100644 memtorch/cpp/solve_sparse_linear.h create mode 100644 memtorch/cu/solve_passive.cpp create mode 100644 memtorch/cu/solve_passive.cuh create mode 100644 memtorch/cu/solve_passive.h create mode 100644 memtorch/cu/solve_passive_kernels.cu create mode 100644 memtorch/cu/solve_passive_kernels.cuh create mode 100644 memtorch/cu/solve_sparse_linear.cpp create mode 100644 memtorch/cu/solve_sparse_linear.h delete mode 100644 tests/test_cpp_extensions.py diff --git a/.gitignore b/.gitignore index f7366f3b..8b4e9ada 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ venv/ /cuda_quantization.cpython-38-x86_64-linux-gnu.so /quantization.cpython-38-x86_64-linux-gnu.so .idea/ -/memtorch.egg-info/ \ No newline at end of file +/memtorch.egg-info/ +*.csv \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 351247db..50b3e7eb 100644 --- 
a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,17 @@ ## Added -1. Added another version of the Data Driven Model defined using `memtorch.bh.memrsitor.Data_Driven2021`. -2. Added CPU- and GPU-bound C++ bindings for `gen_tiles`. -3. Exposed `use_bindings`. -4. Added unit tests for `use_bindings`. -5. Added `exemptAssignees` tag to `scale.yml`. -6. Created `memtorch.map.Input` to encapsulate customizable input scaling methods. -7. Added the `force_scale` input argument to the default scaling method to specify whether inputs are force scaled if they do not exceed `max_input_voltage`. -8. Added CPU and GPU bindings for `tiled_inference`. +1. Added Patching Support for `torch.nn.Sequential` containers. +2. Added support for modeling source and line resistances for passive crossbars/tiles. +3. Added C++ and CUDA bindings for modeling source and line resistances for passive crossbars/tiles\*. +4. Added a new MemTorch logo to `README.md` +5. Added the `set_cuda_malloc_heap_size` routine to patched `torch.mn` modules. +6. Added unit tests for source and line resistance modeling. +7. Relaxed requirements for programming passive crossbars/tiles. -## Enhanced - -1. Modularized input scaling logic for all layer types. -2. Modularized `tile_inference` for all layer types. -3. Updated ReadTheDocs documentation. +**\*Note** it is strongly suggested to set `cuda_malloc_heap_size` using `m.set_cuda_malloc_heap_size` manually when simulating source and line resistances using CUDA bindings. -## Fixed +## Enhanced -1. Fixed GitHub Action Workflows for external pull requests. -2. Fixed error raised by `memtorch.map.Parameter` when `p_l` is defined. -3. Fixed semantic error in `memtorch.cpp.gen_tiles`. +1. Modularized patching logic in `memtorch.bh.nonideality.NonIdeality` and `memtorch.mn.Module`. +2. Updated `ReadTheDocs` documentation. +3. Transitioned from `Gitter` to `GitHub Discussions` for general discussion. diff --git a/README.md b/README.md index 0ac44712..458cd8c7 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@


- MemTorch + MemTorch

[![](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/) ![](https://img.shields.io/badge/license-GPL-blue.svg) ![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3760695.svg) -[![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/memtorch/community) +[![GitHub Discussions](https://img.shields.io/badge/chat-discussions-ff69b4)](https://github.com/coreylammie/MemTorch/discussions/97) ![](https://readthedocs.org/projects/pip/badge/?version=latest) [![CI](https://github.com/coreylammie/MemTorch/actions/workflows/push_pull.yml/badge.svg)](https://github.com/coreylammie/MemTorch/actions/workflows/push_pull.yml) [![codecov](https://codecov.io/gh/coreylammie/MemTorch/branch/master/graph/badge.svg)](https://codecov.io/gh/coreylammie/MemTorch) diff --git a/docs/conf.py b/docs/conf.py index 13b92f5d..f02cbfa8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,7 @@ author = "Corey Lammie" # The full version, including alpha/beta/rc tags -release = "1.1.3" +release = "1.1.4" autodoc_inherit_docstrings = False # -- General configuration --------------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index fbaf3759..e2c3a415 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,4 +23,4 @@ We provide documentation in the form of a complete Python API, and numerous inte memtorch tutorials - Discuss MemTorch on Gitter + Discuss MemTorch on GitHub Discussions diff --git a/logo.svg b/logo.svg new file mode 100644 index 00000000..f3399edf --- /dev/null +++ b/logo.svg @@ -0,0 +1,112 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + diff --git a/memtorch/bh/crossbar/Crossbar.py b/memtorch/bh/crossbar/Crossbar.py index 8d85f030..ff062e10 100644 --- a/memtorch/bh/crossbar/Crossbar.py +++ b/memtorch/bh/crossbar/Crossbar.py @@ -203,16 +203,13 @@ def write_conductance_matrix( conductance_matrix = torch.max( torch.min(conductance_matrix.to(self.device), max), min ) - if transistor: + if transistor or programming_routine is None: self.conductance_matrix = conductance_matrix self.max_abs_conductance = ( torch.abs(self.conductance_matrix).flatten().max() ) self.update(from_devices=False) else: - assert ( - programming_routine is not None - ), "programming_routine must be defined if transistor is False." if self.tile_shape is not None: for i in range(0, self.devices.shape[0]): for j in range(0, self.devices.shape[1]): diff --git a/memtorch/bh/crossbar/Passive.py b/memtorch/bh/crossbar/Passive.py new file mode 100644 index 00000000..a96fb9f7 --- /dev/null +++ b/memtorch/bh/crossbar/Passive.py @@ -0,0 +1,253 @@ +import numpy as np +import torch + +import memtorch + +if "cpu" not in memtorch.__version__: + import memtorch_cuda_bindings + +import memtorch_bindings + + +def solve_passive( + conductance_matrix, + V_WL, + V_BL, + R_source, + R_line, + n_input_batches=None, + det_readout_currents=True, + use_bindings=True, + cuda_malloc_heap_size=None, +): + device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") + assert R_source != 0 or R_line != 0, "R_source or R_line must be non-zero." + assert R_source >= 0 and R_line >= 0, "R_source and R_line must be >=0." 
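+    # When use_bindings is True, the solve is dispatched to the compiled C++ (memtorch_bindings) or CUDA (memtorch_cuda_bindings) solver; otherwise, the pure-PyTorch fallback below assembles the sparse (ABCD)V = E nodal system (8 * m * n - 2 * m - 2 * n non-zero entries) and solves it densely with torch.linalg.solve.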
+ if use_bindings: + if n_input_batches is None: + if "cpu" in memtorch.__version__: + return memtorch_bindings.solve_passive( + conductance_matrix, + V_WL, + V_BL, + R_source, + R_line, + det_readout_currents=det_readout_currents, + ) + else: + if cuda_malloc_heap_size is None: + return memtorch_cuda_bindings.solve_passive( + conductance_matrix, + V_WL, + V_BL, + R_source, + R_line, + det_readout_currents=det_readout_currents, + ) + else: + return memtorch_cuda_bindings.solve_passive( + conductance_matrix, + V_WL, + V_BL, + R_source, + R_line, + det_readout_currents=det_readout_currents, + cuda_malloc_heap_size=cuda_malloc_heap_size, + ) + else: + return memtorch_bindings.solve_passive( + conductance_matrix.cpu(), + V_WL.cpu(), + V_BL.cpu(), + R_source, + R_line, + n_input_batches=n_input_batches, + ).to(device) + else: + m = conductance_matrix.shape[0] + n = conductance_matrix.shape[1] + indices = torch.zeros(2, 8 * m * n - 2 * m - 2 * n, device=device) + values = torch.zeros(8 * m * n - 2 * m - 2 * n, device=device) + mn_range = torch.arange(m * n) + m_range = torch.arange(m) + n_range = torch.arange(n) + index = 0 + # A matrix + for i in range(m): + indices[0:2, index] = i * n + if R_source == 0: + values[index] = conductance_matrix[i, 0] + 1 / R_line + elif R_line == 0: + values[index] = conductance_matrix[i, 0] + 1 / R_source + else: + values[index] = conductance_matrix[i, 0] + 1 / R_source + 1 / R_line + + index += 1 + indices[0, index] = i * n + 1 + indices[1, index] = i * n + if R_line == 0: + values[index : index + 2] = 0 + else: + values[index : index + 2] = -1 / R_line + + index += 1 + indices[0, index] = i * n + indices[1, index] = i * n + 1 + index += 1 + indices[0:2, index] = i * n + (n - 1) + if R_line == 0: + values[index] = conductance_matrix[i, n - 1] + else: + values[index] = conductance_matrix[i, n - 1] + 1 / R_line + + index += 1 + for j in range(1, n - 1): + indices[0:2, index] = i * n + j + if R_line == 0: + values[index] = conductance_matrix[i, j] + else: + values[index] = conductance_matrix[i, j] + 2 / R_line + index += 1 + indices[0, index] = i * n + j + 1 + indices[1, index] = i * n + j + if R_line == 0: + values[index : index + 2] = 0 + else: + values[index : index + 2] = -1 / R_line + + index += 1 + indices[0, index] = i * n + j + indices[1, index] = i * n + j + 1 + index += 1 + + # B matrix + indices[0, index : index + (m * n)] = mn_range + indices[1, index : index + (m * n)] = ( + indices[0, index : index + (m * n)] + m * n + ) + values[index : index + (m * n)] = -conductance_matrix[ + n_range.repeat_interleave(m), n_range.repeat(m) + ] + index = index + (m * n) + # C matrix + indices[0, index : index + (m * n)] = mn_range + m * n + del mn_range + indices[1, index : index + (m * n)] = n * m_range.repeat( + n + ) + n_range.repeat_interleave(m) + values[index : index + (m * n)] = conductance_matrix[ + m_range.repeat_interleave(n), n_range.repeat(m) + ] + index = index + (m * n) + # D matrix + for j in range(n): + indices[0, index] = m * n + (j * m) + indices[1, index] = m * n + j + if R_line == 0: + values[index] = -conductance_matrix[0, j] + else: + values[index] = -1 / R_line - conductance_matrix[0, j] + + index += 1 + indices[0, index] = m * n + (j * m) + indices[1, index] = m * n + j + n + if R_line == 0: + values[index : index + 2] = 0 + else: + values[index : index + 2] = 1 / R_line + + index += 1 + indices[0, index : index + 2] = m * n + (j * m) + m - 1 + indices[1, index] = m * n + (n * (m - 2)) + j + index += 1 + indices[1, index] = m * n + (n * (m - 
1)) + j + if R_source == 0: + values[index] = -conductance_matrix[m - 1, j] - 1 / R_line + elif R_line == 0: + values[index] = -1 / R_source - conductance_matrix[m - 1, j] + else: + values[index] = ( + -1 / R_source - conductance_matrix[m - 1, j] - 1 / R_line + ) + + index += 1 + indices[0, index : index + 3 * (m - 2)] = ( + m * n + (j * m) + m_range[1:-1].repeat_interleave(3) + ) + for i in range(1, m - 1): + indices[1, index] = m * n + (n * (i - 1)) + j + if R_line == 0: + values[index : index + 2] = 0 + else: + values[index : index + 2] = 1 / R_line + + index += 1 + indices[1, index] = m * n + (n * (i + 1)) + j + index += 1 + indices[1, index] = m * n + (n * i) + j + if R_line == 0: + values[index] = -conductance_matrix[i, j] + else: + values[index] = -conductance_matrix[i, j] - 2 / R_line + + index += 1 + + if n_input_batches is None: + E_matrix = torch.zeros(2 * m * n, device=device) + if R_source == 0: + E_matrix[m_range * n] = V_WL.to(device) # E_W values + E_matrix[m * n + (n_range + 1) * m - 1] = -V_BL.to(device) # E_B values + else: + # E_W values + E_matrix[m_range * n] = V_WL.to(device) / R_source + E_matrix[m * n + (n_range + 1) * m - 1] = ( + -V_BL.to(device) / R_source + ) # E_B values + + V = torch.linalg.solve( + torch.sparse_coo_tensor( + indices, values, (2 * m * n, 2 * m * n), device=device + ).to_dense(), + E_matrix, + ) + V_applied = torch.zeros((m, n), device=device) + for i in m_range: + V_applied[i, n_range] = V[n * i + n_range] - V[m * n + n * i + n_range] + if not det_readout_currents: + return V_applied + else: + return torch.sum(torch.mul(V_applied, conductance_matrix.to(device)), 0) + else: + out = torch.zeros(n_input_batches, n, device=device) + for i in range(n_input_batches): + E_matrix = torch.zeros(2 * m * n, device=device) + if R_source == 0: + E_matrix[m_range * n] = V_WL[i, :].to(device) # E_W values + E_matrix[m * n + (n_range + 1) * m - 1] = -V_BL[i, :].to( + device + ) # E_B values + else: + E_matrix[m_range * n] = ( + V_WL[i, :].to(device) / R_source + ) # E_W values + E_matrix[m * n + (n_range + 1) * m - 1] = ( + -V_BL[i, :].to(device) / R_source + ) # E_B values + + V = torch.linalg.solve( + torch.sparse_coo_tensor( + indices, values, (2 * m * n, 2 * m * n), device=device + ).to_dense(), + E_matrix, + ) + V_applied = torch.zeros((m, n), device=device) + for j in m_range: + V_applied[j, n_range] = ( + V[n * j + n_range] - V[m * n + n * j + n_range] + ) + + out[i, :] = torch.sum( + torch.mul(V_applied, conductance_matrix.to(device)), 0 + ) + + return out diff --git a/memtorch/bh/crossbar/Tile.py b/memtorch/bh/crossbar/Tile.py index 7e8fc37f..415181b7 100644 --- a/memtorch/bh/crossbar/Tile.py +++ b/memtorch/bh/crossbar/Tile.py @@ -158,9 +158,12 @@ def tile_matmul_row( mat_b_tiles, mat_b_tiles_map, mat_b_shape, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, + transistor=True, ): """Method to perform row-wise tile matrix multiplication, given two sets of tiles, using a pythonic approach. @@ -176,12 +179,18 @@ def tile_matmul_row( Tiles map for matrix B. mat_b_shape : int, int Shape of matrix B. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. ADC_overflow_rate : float Overflow rate threshold for linear quanitzation (if ADC_resolution is not None). 
quant_method: str Quantization method. Must be in memtorch.bh.Quantize.quant_methods. + transistor : bool + Whether a 1T1R arrangement is simulated (True) or the crossbar is solved as a passive array using source_resistance and line_resistance (False). Returns ------- @@ -208,17 +217,44 @@ def tile_matmul_row( for i in range(mat_b_tiles_map.shape[0]): tile_a = mat_a_row_tiles[int(mat_a_tiles_map[i])] tile_b = mat_b_tiles[int(mat_b_tiles_map[i][j])] - if quant_method is not None: - partial_sum[j] += memtorch.bh.Quantize.quantize( - torch.matmul(tile_a.to(device), tile_b.to(device)).squeeze(), - quant=ADC_resolution, - overflow_rate=ADC_overflow_rate, - quant_method=quant_method, - ) + if transistor: + if quant_method is not None: + partial_sum[j] += memtorch.bh.Quantize.quantize( + torch.matmul(tile_a.to(device), tile_b.to(device)).squeeze(), + quant=ADC_resolution, + overflow_rate=ADC_overflow_rate, + quant_method=quant_method, + ) + else: + partial_sum[j] += torch.matmul( + tile_a.to(device), tile_b.to(device) + ).squeeze() else: - partial_sum[j] += torch.matmul( - tile_a.to(device), tile_b.to(device) - ).squeeze() + if quant_method is not None: + partial_sum[j] += memtorch.bh.Quantize.quantize( + memtorch.bh.crossbar.Passive.solve_passive( + tile_b, + tile_a, + torch.zeros(tile_b.shape[1]), + source_resistance, + line_resistance, + det_readout_currents=True, + use_bindings=False, + ), + quant=ADC_resolution, + overflow_rate=ADC_overflow_rate, + quant_method=quant_method, + ) + else: + partial_sum[j] += memtorch.bh.crossbar.Passive.solve_passive( + tile_b, + tile_a, + torch.zeros(tile_b.shape[1]), + source_resistance, + line_resistance, + det_readout_currents=True, + use_bindings=False, + ) output_act = partial_sum.flatten() output_act = output_act[: mat_b_shape[1]] @@ -232,9 +268,12 @@ def tile_matmul( mat_b_tiles, mat_b_tiles_map, mat_b_shape, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, + transistor=True, use_bindings=True, cuda_malloc_heap_size=50, ): @@ -254,12 +293,18 @@ def tile_matmul( Tiles map for matrix B. mat_b_shape : int, int Shape of matrix B. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. ADC_overflow_rate : float Overflow rate threshold for linear quanitzation (if ADC_resolution is not None). quant_method: str Quantization method. Must be in memtorch.bh.Quantize.quant_methods. + transistor : bool + Whether a 1T1R arrangement is simulated (True) or the crossbar is solved as a passive array using source_resistance and line_resistance (False). use_bindings : bool Use C++/CUDA bindings to parallelize tile_matmul operations (True). cuda_malloc_heap_size : int @@ -276,6 +321,14 @@ def tile_matmul( and len(mat_b_tiles.shape) == 3 and mat_a_tiles.shape[-2] != 0 ), "Incompatible tile shapes used." + if source_resistance is not None and line_resistance is not None: + assert ( + source_resistance != 0 or line_resistance != 0 + ), "R_source or R_line must be non-zero." + assert ( + source_resistance >= 0 and line_resistance >= 0 + ), "R_source and R_line must be >=0." 
+ if use_bindings: if quant_method is None: return memtorch_bindings.tile_matmul( @@ -313,9 +366,12 @@ def tile_matmul( mat_b_tiles, mat_b_tiles_map, mat_b_shape, + source_resistance, + line_resistance, ADC_resolution, ADC_overflow_rate, quant_method, + transistor, ) else: result = tile_matmul_row( @@ -324,14 +380,17 @@ def tile_matmul( mat_b_tiles, mat_b_tiles_map, mat_b_shape, + source_resistance, + line_resistance, ADC_resolution, ADC_overflow_rate, quant_method, + transistor, ) return result -def tiled_inference(input, m): +def tiled_inference(input, m, transistor): """Method to perform tiled inference. Parameters @@ -348,36 +407,137 @@ def tiled_inference(input, m): """ tiles_map = m.crossbars[0].tiles_map crossbar_shape = (m.crossbars[0].rows, m.crossbars[0].columns) + if m.source_resistance is not None and m.line_resistance is not None: + assert ( + m.source_resistance != 0 or m.line_resistance != 0 + ), "R_source or R_line must be non-zero." + assert ( + m.source_resistance >= 0 and m.line_resistance >= 0 + ), "R_source and R_line must be >=0." + if m.use_bindings: quant_method = m.quant_method if quant_method is None: - return memtorch_bindings.tiled_inference( - input, - input.shape, - m.tile_shape, - m.crossbar_operation( - m.crossbars, lambda crossbar: crossbar.conductance_matrix - ), - m.crossbars[0].tiles_map, - (m.crossbars[0].rows, m.crossbars[0].columns), - ) + if transistor: + if "cpu" in memtorch.__version__: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + m.crossbars[0].tiles_map, + (m.crossbars[0].rows, m.crossbars[0].columns), + ) + else: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + m.crossbars[0].tiles_map, + (m.crossbars[0].rows, m.crossbars[0].columns), + cuda_malloc_heap_size=m.cuda_malloc_heap_size, + ) + else: + if "cpu" in memtorch.__version__: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + m.crossbars[0].tiles_map, + (m.crossbars[0].rows, m.crossbars[0].columns), + m.source_resistance, + m.line_resistance, + ) + else: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + m.crossbars[0].tiles_map, + (m.crossbars[0].rows, m.crossbars[0].columns), + m.source_resistance, + m.line_resistance, + cuda_malloc_heap_size=m.cuda_malloc_heap_size, + ) else: assert ( quant_method in memtorch.bh.Quantize.quant_methods ), "quant_method is invalid." 
- return memtorch_bindings.tiled_inference( - input, - input.shape, - m.tile_shape, - m.crossbar_operation( - m.crossbars, lambda crossbar: crossbar.conductance_matrix - ), - tiles_map, - crossbar_shape, - m.ADC_resolution, - m.ADC_overflow_rate, - memtorch.bh.Quantize.quant_methods.index(quant_method), - ) + if transistor: + if "cpu" in memtorch.__version__: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + tiles_map, + crossbar_shape, + m.ADC_resolution, + m.ADC_overflow_rate, + memtorch.bh.Quantize.quant_methods.index(quant_method), + ) + else: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + tiles_map, + crossbar_shape, + m.ADC_resolution, + m.ADC_overflow_rate, + memtorch.bh.Quantize.quant_methods.index(quant_method), + cuda_malloc_heap_size=m.cuda_malloc_heap_size, + ) + else: + if "cpu" in memtorch.__version__: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + tiles_map, + crossbar_shape, + m.source_resistance, + m.line_resistance, + m.ADC_resolution, + m.ADC_overflow_rate, + memtorch.bh.Quantize.quant_methods.index(quant_method), + ) + else: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + tiles_map, + crossbar_shape, + m.source_resistance, + m.line_resistance, + m.ADC_resolution, + m.ADC_overflow_rate, + memtorch.bh.Quantize.quant_methods.index(quant_method), + cuda_malloc_heap_size=m.cuda_malloc_heap_size, + ) else: (input_tiles, input_tiles_map) = gen_tiles( input, @@ -394,8 +554,11 @@ def tiled_inference(input, m): ), tiles_map, crossbar_shape, + m.source_resistance, + m.line_resistance, m.ADC_resolution, m.ADC_overflow_rate, m.quant_method, + m.transistor, use_bindings=False, ) diff --git a/memtorch/bh/crossbar/__init__.py b/memtorch/bh/crossbar/__init__.py index cd048d25..8f664ac4 100644 --- a/memtorch/bh/crossbar/__init__.py +++ b/memtorch/bh/crossbar/__init__.py @@ -1,3 +1,4 @@ from .Crossbar import * +from .Passive import * from .Program import * from .Tile import * diff --git a/memtorch/bh/nonideality/endurance_retention_models/empirical_metal_oxide_RRAM.py b/memtorch/bh/nonideality/endurance_retention_models/empirical_metal_oxide_RRAM.py index 3538e57b..329d4301 100644 --- a/memtorch/bh/nonideality/endurance_retention_models/empirical_metal_oxide_RRAM.py +++ b/memtorch/bh/nonideality/endurance_retention_models/empirical_metal_oxide_RRAM.py @@ -99,7 +99,7 @@ def model_endurance_retention_gradual( """ return 10 ** ( p_3 * (p_1 * cell_size + p_2 * temperature_constant) * np.log10(x) - + np.log10(initial_resistance) + + torch.log10(initial_resistance) - p_3 * (p_1 * cell_size + p_2 * temperature_constant) * np.log10(threshold) ) diff --git a/memtorch/cpp/bindings.cpp b/memtorch/cpp/bindings.cpp index 793cc054..96411bc3 100644 --- a/memtorch/cpp/bindings.cpp +++ b/memtorch/cpp/bindings.cpp @@ -5,16 +5,13 @@ #include "gen_tiles.h" #include "inference.h" #include "quantize.h" +#include "solve_passive.h" #include "tile_matmul.h" -void quantize_bindings(py::module_ &); -void gen_tiles_bindings(py::module_ &); -void tile_matmul_bindings(py::module_ &); -void inference_bindings(py::module_ 
&); - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { quantize_bindings(m); gen_tiles_bindings(m); tile_matmul_bindings(m); inference_bindings(m); + solve_passive_bindings(m); } \ No newline at end of file diff --git a/memtorch/cpp/inference.cpp b/memtorch/cpp/inference.cpp index f23d0b66..cc3e2b07 100644 --- a/memtorch/cpp/inference.cpp +++ b/memtorch/cpp/inference.cpp @@ -16,6 +16,19 @@ at::Tensor tiled_inference(at::Tensor input, int input_shape[2], weight_tiles_map, weight_shape); } +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + float source_resistance, float line_resistance) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCPU)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, source_resistance, + line_resistance); +} + at::Tensor tiled_inference(at::Tensor input, int input_shape[2], int tile_shape[2], at::Tensor weight_tiles, at::Tensor weight_tiles_map, int weight_shape[2], @@ -30,8 +43,24 @@ at::Tensor tiled_inference(at::Tensor input, int input_shape[2], ADC_overflow_rate, quant_method); } +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + float source_resistance, float line_resistance, + int ADC_resolution, float ADC_overflow_rate, + int quant_method) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCPU)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, source_resistance, + line_resistance, ADC_resolution, ADC_overflow_rate, + quant_method); +} + void inference_bindings(py::module_ &m) { - // Binding without quantization support + // Binding without quantization support (transistor=True) m.def( "tiled_inference", [](at::Tensor input, std::tuple input_shape, @@ -53,7 +82,32 @@ void inference_bindings(py::module_ &m) { py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), py::arg("weight_tiles"), py::arg("weight_tiles_map"), py::arg("weight_shape")); - // Binding with quantization support + // Binding without quantization support (transistor=False) + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + float source_resistance, float line_resistance) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference(input, input_shape_array, tile_shape_array, + weight_tiles, weight_tiles_map, + weight_shape_array, source_resistance, + line_resistance); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("source_resistance"), + py::arg("line_resistance")); 
+ // Binding with quantization support (transistor=True) m.def( "tiled_inference", [](at::Tensor input, std::tuple input_shape, @@ -78,4 +132,31 @@ void inference_bindings(py::module_ &m) { py::arg("weight_tiles"), py::arg("weight_tiles_map"), py::arg("weight_shape"), py::arg("ADC_resolution"), py::arg("ADC_overflow_rate"), py::arg("quant_method")); + // Binding with quantization support (transistor=False) + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + float source_resistance, float line_resistance, int ADC_resolution, + float ADC_overflow_rate, int quant_method) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference( + input, input_shape_array, tile_shape_array, weight_tiles, + weight_tiles_map, weight_shape_array, source_resistance, + line_resistance, ADC_resolution, ADC_overflow_rate, quant_method); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("source_resistance"), + py::arg("line_resistance"), py::arg("ADC_resolution"), + py::arg("ADC_overflow_rate"), py::arg("quant_method")); } \ No newline at end of file diff --git a/memtorch/cpp/solve_passive.cpp b/memtorch/cpp/solve_passive.cpp new file mode 100644 index 00000000..f40bb59a --- /dev/null +++ b/memtorch/cpp/solve_passive.cpp @@ -0,0 +1,272 @@ +#include +#include +#include + +#include +#include + +#include + +typedef Eigen::Triplet sparse_element; +typedef std::vector> triplet_vector; + +triplet_vector gen_ABCDE(Eigen::MatrixXf conductance_matrix_accessor, int m, + int n, float *V_WL_accessor, float *V_BL_accessor, + float R_source, float R_line, + triplet_vector ABCD_matrix, + float *E_matrix_accessor = NULL) { +// A, B, and E (partial) matrices +#pragma omp parallel for + for (int i = 0; i < m; i++) { + // A matrix + if (R_source == 0) { + ABCD_matrix.push_back(sparse_element( + i * n, i * n, conductance_matrix_accessor(i, 0) + 1.0f / R_line)); + } else if (R_line == 0) { + ABCD_matrix.push_back(sparse_element( + i * n, i * n, conductance_matrix_accessor(i, 0) + 1.0f / R_source)); + } else { + ABCD_matrix.push_back(sparse_element( + i * n, i * n, + conductance_matrix_accessor(i, 0) + 1.0f / R_source + 1.0f / R_line)); + } + if (R_line == 0) { + ABCD_matrix.push_back(sparse_element(i * n + 1, i * n, 0)); + ABCD_matrix.push_back(sparse_element(i * n, i * n + 1, 0)); + ABCD_matrix.push_back( + sparse_element(i * n + (n - 1), i * n + (n - 1), + conductance_matrix_accessor(i, n - 1))); + } else { + ABCD_matrix.push_back(sparse_element(i * n + 1, i * n, -1.0f / R_line)); + ABCD_matrix.push_back(sparse_element(i * n, i * n + 1, -1.0f / R_line)); + ABCD_matrix.push_back( + sparse_element(i * n + (n - 1), i * n + (n - 1), + conductance_matrix_accessor(i, n - 1) + 1.0 / R_line)); + } + // B matrix + ABCD_matrix.push_back(sparse_element(i * n, i * n + (m * n), + -conductance_matrix_accessor(i, 0))); + ABCD_matrix.push_back( + sparse_element(i * n + (n - 1), i * n + (n - 1) + (m * n), + 
-conductance_matrix_accessor(i, n - 1))); + // E matrix + if (E_matrix_accessor != NULL) { + if (R_source == 0) { + E_matrix_accessor[i * n] = V_WL_accessor[i]; + } else { + E_matrix_accessor[i * n] = V_WL_accessor[i] / R_source; + } + } +#pragma omp for nowait + for (int j = 1; j < n - 1; j++) { + // A matrix + if (R_line == 0) { + ABCD_matrix.push_back(sparse_element( + i * n + j, i * n + j, conductance_matrix_accessor(i, j))); + ABCD_matrix.push_back(sparse_element(i * n + j + 1, i * n + j, 0)); + ABCD_matrix.push_back(sparse_element(i * n + j, i * n + j + 1, 0)); + } else { + ABCD_matrix.push_back( + sparse_element(i * n + j, i * n + j, + conductance_matrix_accessor(i, j) + 2.0f / R_line)); + ABCD_matrix.push_back( + sparse_element(i * n + j + 1, i * n + j, -1.0f / R_line)); + ABCD_matrix.push_back( + sparse_element(i * n + j, i * n + j + 1, -1.0f / R_line)); + } + // B matrix + ABCD_matrix.push_back(sparse_element(i * n + j, i * n + j + (m * n), + -conductance_matrix_accessor(i, j))); + } + } + // C, D, and E (partial) matrices +#pragma omp parallel for + for (int j = 0; j < n; j++) { + // D matrix + if (R_line == 0) { + ABCD_matrix.push_back(sparse_element(m * n + (j * m), m * n + j, + -conductance_matrix_accessor(0, j))); + ABCD_matrix.push_back(sparse_element(m * n + (j * m), m * n + j + n, 0)); + ABCD_matrix.push_back(sparse_element(m * n + (j * m) + m - 1, + m * n + (n * (m - 2)) + j, 0)); + } else { + ABCD_matrix.push_back( + sparse_element(m * n + (j * m), m * n + j, + -1.0f / R_line - conductance_matrix_accessor(0, j))); + ABCD_matrix.push_back( + sparse_element(m * n + (j * m), m * n + j + n, 1.0f / R_line)); + ABCD_matrix.push_back(sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 2)) + j, 1.0f / R_line)); + } + if (R_source == 0) { + ABCD_matrix.push_back(sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j, + -conductance_matrix_accessor(m - 1, j) - 1.0f / R_line)); + } else if (R_line == 0) { + ABCD_matrix.push_back(sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j, + -1.0f / R_source - conductance_matrix_accessor(m - 1, j))); + } else { + ABCD_matrix.push_back(sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j, + -1.0f / R_source - conductance_matrix_accessor(m - 1, j) - + 1.0f / R_line)); + } + // C matrix + ABCD_matrix.push_back( + sparse_element(j * m + (m * n), j, conductance_matrix_accessor(0, j))); + ABCD_matrix.push_back( + sparse_element(j * m + (m - 1) + (m * n), n * (m - 1) + j, + conductance_matrix_accessor(m - 1, j))); + // E matrix + if (E_matrix_accessor != NULL) { + if (R_source == 0) { + E_matrix_accessor[m * n + (j + 1) * m - 1] = -V_BL_accessor[j]; + } else { + E_matrix_accessor[m * n + (j + 1) * m - 1] = + -V_BL_accessor[j] / R_source; + } + } +#pragma omp for nowait + for (int i = 1; i < m - 1; i++) { + // D matrix + if (R_line == 0) { + ABCD_matrix.push_back( + sparse_element(m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 0)); + ABCD_matrix.push_back( + sparse_element(m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 0)); + ABCD_matrix.push_back( + sparse_element(m * n + (j * m) + i, m * n + (n * i) + j, + -conductance_matrix_accessor(i, j))); + } else { + ABCD_matrix.push_back(sparse_element( + m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 1.0f / R_line)); + ABCD_matrix.push_back(sparse_element( + m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 1.0f / R_line)); + ABCD_matrix.push_back( + sparse_element(m * n + (j * m) + i, m * n + (n * i) + j, + -conductance_matrix_accessor(i, j) - 
2.0f / R_line)); + } + // C matrix + ABCD_matrix.push_back(sparse_element(j * m + i + (m * n), n * i + j, + conductance_matrix_accessor(i, j))); + } + } + return ABCD_matrix; +} + +at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL, + at::Tensor V_BL, float R_source, float R_line, + bool det_readout_currents) { + int m = conductance_matrix.sizes()[0]; + int n = conductance_matrix.sizes()[1]; + Eigen::Map> + conductance_matrix_accessor(conductance_matrix.data_ptr(), m, n, + Eigen::Stride<1, Eigen::Dynamic>(1, n)); + float *V_WL_accessor = V_WL.data_ptr(); + float *V_BL_accessor = V_BL.data_ptr(); + int non_zero_elements = 8 * m * n - 2 * m - 2 * n; + triplet_vector ABCD_matrix; + ABCD_matrix.reserve(non_zero_elements); + float *E_matrix_accessor = (float *)malloc(sizeof(float) * (2 * m * n)); +#pragma omp parallel for + for (int i = 0; i < (2 * m * n); i++) { + E_matrix_accessor[i] = 0; + } + ABCD_matrix = + gen_ABCDE(conductance_matrix_accessor, m, n, V_WL_accessor, V_BL_accessor, + R_source, R_line, ABCD_matrix, E_matrix_accessor); + Eigen::Map E_matrix(E_matrix_accessor, (2 * m * n)); + // Solve (ABCD)V = E + Eigen::SparseMatrix ABCD(2 * m * n, 2 * m * n); + ABCD.setFromTriplets(ABCD_matrix.begin(), ABCD_matrix.end()); + Eigen::SparseLU> solver; + solver.compute(ABCD); + Eigen::VectorXf V = solver.solve(E_matrix); + at::Tensor V_applied_tensor = at::zeros({m, n}); +#pragma omp parallel for + for (int i = 0; i < m; i++) { +#pragma omp for nowait + for (int j = 0; j < n; j++) { + V_applied_tensor.index_put_({i, j}, V[n * i + j] - V[m * n + n * i + j]); + } + } + if (!det_readout_currents) { + return V_applied_tensor; + } else { + return at::sum(at::mul(V_applied_tensor, conductance_matrix), 0); + } +} + +at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL, + at::Tensor V_BL, float R_source, float R_line, + int n_input_batches) { + int m = conductance_matrix.sizes()[0]; + int n = conductance_matrix.sizes()[1]; + Eigen::Map> + conductance_matrix_accessor(conductance_matrix.data_ptr(), m, n, + Eigen::Stride<1, Eigen::Dynamic>(1, n)); + Eigen::Map> + V_WL_accessor(V_WL.data_ptr(), n_input_batches, m, + Eigen::Stride<1, Eigen::Dynamic>(1, m)); + Eigen::Map> + V_BL_accessor(V_BL.data_ptr(), n_input_batches, n, + Eigen::Stride<1, Eigen::Dynamic>(1, n)); + int non_zero_elements = 8 * m * n - 2 * m - 2 * n; + triplet_vector ABCD_matrix; + ABCD_matrix.reserve(non_zero_elements); + ABCD_matrix = gen_ABCDE(conductance_matrix_accessor, m, n, NULL, NULL, + R_source, R_line, ABCD_matrix, NULL); + // Solve (ABCD)V = E + Eigen::SparseMatrix ABCD(2 * m * n, 2 * m * n); + ABCD.setFromTriplets(ABCD_matrix.begin(), ABCD_matrix.end()); + Eigen::SparseLU> solver; + solver.compute(ABCD); + at::Tensor out = at::zeros({n_input_batches, n}); +#pragma omp parallel for + for (int i = 0; i < n_input_batches; i++) { + Eigen::VectorXf E_matrix = Eigen::VectorXf::Zero(2 * m * n); + for (int j = 0; j < m; j++) { + E_matrix(j * n) = V_WL_accessor(i, j) / R_source; + } + for (int k = 0; k < n; k++) { + E_matrix(m * n + (k + 1) * m - 1) = -V_BL_accessor(i, k) / R_source; + } + Eigen::VectorXf V = solver.solve(E_matrix); + at::Tensor V_applied_tensor = at::zeros({m, n}); +#pragma omp parallel for + for (int j = 0; j < m; j++) { +#pragma omp for nowait + for (int k = 0; k < n; k++) { + V_applied_tensor.index_put_({j, k}, V_applied_tensor.index({j, k}) + + V[n * j + k] - + V[m * n + n * j + k]); + } + } + out.index_put_({i, torch::indexing::Slice()}, + at::sum(at::mul(V_applied_tensor, 
conductance_matrix), 0)); + } + return out; +} + +void solve_passive_bindings(py::module_ &m) { + m.def( + "solve_passive", + [&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL, + float R_source, float R_line, bool det_readout_currents) { + return solve_passive(conductance_matrix, V_WL, V_BL, R_source, R_line, + det_readout_currents); + }, + py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"), + py::arg("R_source"), py::arg("R_line"), + py::arg("det_readout_currents") = true); + m.def( + "solve_passive", + [&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL, + float R_source, float R_line, int n_input_batches) { + return solve_passive(conductance_matrix, V_WL, V_BL, R_source, R_line, + n_input_batches); + }, + py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"), + py::arg("R_source"), py::arg("R_line"), py::arg("n_input_batches")); +} \ No newline at end of file diff --git a/memtorch/cpp/solve_passive.h b/memtorch/cpp/solve_passive.h new file mode 100644 index 00000000..f7610766 --- /dev/null +++ b/memtorch/cpp/solve_passive.h @@ -0,0 +1,11 @@ +#include +#include +#include + +void solve_passive_bindings(py::module_ &m); +at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL, + at::Tensor V_BL, float R_source, float R_line, + bool det_readout_currents); +at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL, + at::Tensor V_BL, float R_source, float R_line, + int n_input_batches); \ No newline at end of file diff --git a/memtorch/cpp/solve_sparse_linear.cpp b/memtorch/cpp/solve_sparse_linear.cpp new file mode 100644 index 00000000..cd421963 --- /dev/null +++ b/memtorch/cpp/solve_sparse_linear.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include + +void solve_sparse_linear(Eigen::SparseMatrix A, double *B_values, + int n) { + Eigen::SparseQR, Eigen::COLAMDOrdering> QR( + A); + QR.analyzePattern(A); + QR.factorize(A); + Eigen::Map B(B_values, n); + Eigen::VectorXd X = QR.solve(B); + memcpy(B_values, X.data(), sizeof(double) * n); +} + +void solve_sparse_linear(Eigen::SparseMatrix A, float *B_values, int n) { + Eigen::SparseQR, Eigen::COLAMDOrdering> QR(A); + QR.analyzePattern(A); + QR.factorize(A); + Eigen::Map B(B_values, n); + Eigen::VectorXf X = QR.solve(B); + memcpy(B_values, X.data(), sizeof(float) * n); +} \ No newline at end of file diff --git a/memtorch/cpp/solve_sparse_linear.h b/memtorch/cpp/solve_sparse_linear.h new file mode 100644 index 00000000..f35a05ee --- /dev/null +++ b/memtorch/cpp/solve_sparse_linear.h @@ -0,0 +1,7 @@ +#include +#include +#include + +void solve_sparse_linear(Eigen::SparseMatrix A, double *B_values, + int n); +void solve_sparse_linear(Eigen::SparseMatrix A, float *B_values, int n); \ No newline at end of file diff --git a/memtorch/cpp/tile_matmul.cpp b/memtorch/cpp/tile_matmul.cpp index 88e86608..c2b6964e 100644 --- a/memtorch/cpp/tile_matmul.cpp +++ b/memtorch/cpp/tile_matmul.cpp @@ -3,6 +3,8 @@ #include #include "quantize.h" +#include "solve_passive.h" + using namespace torch::indexing; at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, @@ -32,6 +34,37 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, return result; } +at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + int mat_a_shape[2], at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, int mat_b_shape[2], + float source_resistance, float line_resistance) { + int mat_a_rows = mat_a_tiles.sizes().end()[-2]; + c10::IntArrayRef 
mat_b_tiles_shape = mat_b_tiles.sizes(); + c10::IntArrayRef mat_b_tiles_map_shape = mat_b_tiles_map.sizes(); + at::Tensor partial_sum = + at::zeros({mat_b_tiles_map_shape[1], mat_b_tiles_shape.back()}); + at::Tensor result = at::zeros({mat_a_shape[0], mat_b_shape[1]}); +#pragma omp parallel for + for (int i = 0; i < mat_a_rows; i++) { + at::Tensor mat_a_row_tiles = mat_a_tiles.index({Slice(), i, Slice()}); + for (int j = 0; j < mat_b_tiles_map_shape[0]; j++) { + at::Tensor tile_a = mat_a_row_tiles[mat_a_tiles_map[j].item()]; + for (int k = 0; k < mat_b_tiles_map_shape[1]; k++) { + at::Tensor tile_b = mat_b_tiles[mat_b_tiles_map[j][k].item()]; + partial_sum[k] += + solve_passive(tile_b, tile_a, at::zeros({tile_b.sizes()[1]}), + source_resistance, line_resistance, true) + .squeeze(); + } + result.index_put_({i, Slice()}, result.index({i, Slice()}) + + partial_sum.flatten().index( + {Slice(0, mat_b_shape[1])})); + partial_sum = partial_sum.zero_(); + } + } + return result; +} + at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, int mat_a_shape[2], at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, int mat_b_shape[2], @@ -63,14 +96,48 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, return result; } +at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + int mat_a_shape[2], at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, int mat_b_shape[2], + float source_resistance, float line_resistance, + int ADC_resolution, float ADC_overflow_rate, + int quant_method) { + int mat_a_rows = mat_a_tiles.sizes().end()[-2]; + c10::IntArrayRef mat_b_tiles_shape = mat_b_tiles.sizes(); + c10::IntArrayRef mat_b_tiles_map_shape = mat_b_tiles_map.sizes(); + at::Tensor partial_sum = + at::zeros({mat_b_tiles_map_shape[1], mat_b_tiles_shape.back()}); + at::Tensor result = at::zeros({mat_a_shape[0], mat_b_shape[1]}); +#pragma omp parallel for + for (int i = 0; i < mat_a_rows; i++) { + at::Tensor mat_a_row_tiles = mat_a_tiles.index({Slice(), i, Slice()}); + for (int j = 0; j < mat_b_tiles_map_shape[0]; j++) { + partial_sum = + at::zeros({mat_b_tiles_map_shape[1], mat_b_tiles_shape.back()}); + at::Tensor tile_a = mat_a_row_tiles[mat_a_tiles_map[j].item()]; + for (int k = 0; k < mat_b_tiles_map_shape[1]; k++) { + at::Tensor tile_b = mat_b_tiles[mat_b_tiles_map[j][k].item()]; + at::Tensor result = + solve_passive(tile_b, tile_a, at::zeros({tile_b.sizes()[1]}), + source_resistance, line_resistance, true) + .squeeze(); + quantize(result, ADC_resolution, ADC_overflow_rate, quant_method); + partial_sum[k] += result; + } + partial_sum = partial_sum.flatten().index({Slice(0, mat_b_shape[1])}); + result.index_put_({i, Slice()}, result.index({i, Slice()}) + partial_sum); + } + } + return result; +} + void tile_matmul_bindings(py::module_ &m) { - // Binding without quantization support + // Binding without quantization support (transistor=True) m.def( "tile_matmul", [](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, std::tuple mat_a_shape, at::Tensor mat_b_tiles, - at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, - int cuda_malloc_heap_size) { + at::Tensor mat_b_tiles_map, std::tuple mat_b_shape) { assert((std::tuple_size(mat_a_shape) == 2) && (std::tuple_size(mat_b_shape) == 2)); int mat_a_shape_array[2] = {(int)std::get<0>(mat_a_shape), @@ -82,16 +149,35 @@ void tile_matmul_bindings(py::module_ &m) { }, py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), py::arg("mat_a_shape"), py::arg("mat_b_tiles"), + py::arg("mat_b_tiles_map"), 
py::arg("mat_b_shape")); + // Binding without quantization support (transistor=False) + m.def( + "tile_matmul", + [](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + std::tuple mat_a_shape, at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, + float source_resistance, float line_resistance) { + assert((std::tuple_size(mat_a_shape) == 2) && + (std::tuple_size(mat_b_shape) == 2)); + int mat_a_shape_array[2] = {(int)std::get<0>(mat_a_shape), + (int)std::get<1>(mat_a_shape)}; + int mat_b_shape_array[2] = {(int)std::get<0>(mat_b_shape), + (int)std::get<1>(mat_b_shape)}; + return tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape_array, + mat_b_tiles, mat_b_tiles_map, mat_b_shape_array, + source_resistance, line_resistance); + }, + py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), + py::arg("mat_a_shape"), py::arg("mat_b_tiles"), py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), - py::arg("cuda_malloc_heap_size") = NULL); - // Binding with quantization support + py::arg("source_resistance"), py::arg("line_resistance")); + // Binding with quantization support (transistor=True) m.def( "tile_matmul", [](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, std::tuple mat_a_shape, at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, - int ADC_resolution, float ADC_overflow_rate, int quant_method, - int cuda_malloc_heap_size) { + int ADC_resolution, float ADC_overflow_rate, int quant_method) { assert((std::tuple_size(mat_a_shape) == 2) && (std::tuple_size(mat_b_shape) == 2)); int mat_a_shape_array[2] = {(int)std::get<0>(mat_a_shape), @@ -106,5 +192,30 @@ void tile_matmul_bindings(py::module_ &m) { py::arg("mat_a_shape"), py::arg("mat_b_tiles"), py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), py::arg("ADC_resolution"), py::arg("ADC_overflow_rate"), - py::arg("quant_method"), py::arg("cuda_malloc_heap_size") = NULL); + py::arg("quant_method")); + // Binding with quantization support (transistor=False) + m.def( + "tile_matmul", + [](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + std::tuple mat_a_shape, at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, + float source_resistance, float line_resistance, int ADC_resolution, + float ADC_overflow_rate, int quant_method) { + assert((std::tuple_size(mat_a_shape) == 2) && + (std::tuple_size(mat_b_shape) == 2)); + int mat_a_shape_array[2] = {(int)std::get<0>(mat_a_shape), + (int)std::get<1>(mat_a_shape)}; + int mat_b_shape_array[2] = {(int)std::get<0>(mat_b_shape), + (int)std::get<1>(mat_b_shape)}; + return tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape_array, + mat_b_tiles, mat_b_tiles_map, mat_b_shape_array, + source_resistance, line_resistance, ADC_resolution, + ADC_overflow_rate, quant_method); + }, + py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), + py::arg("mat_a_shape"), py::arg("mat_b_tiles"), + py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), + py::arg("source_resistance"), py::arg("line_resistance"), + py::arg("ADC_resolution"), py::arg("ADC_overflow_rate"), + py::arg("quant_method")); } \ No newline at end of file diff --git a/memtorch/cpp/tile_matmul.h b/memtorch/cpp/tile_matmul.h index 76660755..f1d09999 100644 --- a/memtorch/cpp/tile_matmul.h +++ b/memtorch/cpp/tile_matmul.h @@ -5,5 +5,15 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, int mat_a_shape[2], at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, int mat_b_shape[2], + float 
source_resistance, float line_resistance); +at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + int mat_a_shape[2], at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, int mat_b_shape[2], + int ADC_resolution, float ADC_overflow_rate, + int quant_method); +at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + int mat_a_shape[2], at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, int mat_b_shape[2], + float source_resistance, float line_resistance, int ADC_resolution, float ADC_overflow_rate, int quant_method); \ No newline at end of file diff --git a/memtorch/cu/bindings.cpp b/memtorch/cu/bindings.cpp index 20fbe44c..96a691eb 100644 --- a/memtorch/cu/bindings.cpp +++ b/memtorch/cu/bindings.cpp @@ -4,14 +4,13 @@ #include "gen_tiles.h" #include "inference.h" +#include "solve_passive.h" #include "tile_matmul.h" -void tile_matmul_bindings(py::module_ &); -void gen_tiles_bindings_gpu(py::module_ &); -void inference_bindings(py::module_ &); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { gen_tiles_bindings_gpu(m); tile_matmul_bindings(m); inference_bindings(m); + solve_passive_bindings(m); } \ No newline at end of file diff --git a/memtorch/cu/inference.cpp b/memtorch/cu/inference.cpp index 8ee43420..49ee6588 100644 --- a/memtorch/cu/inference.cpp +++ b/memtorch/cu/inference.cpp @@ -11,16 +11,46 @@ at::Tensor tiled_inference(at::Tensor input, int input_shape[2], int cuda_malloc_heap_size) { at::Tensor input_tiles; at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCUDA, 0)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, NULL, NULL, -1, -1, -1, + cuda_malloc_heap_size); +} + +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + float source_resistance, float line_resistance, + int cuda_malloc_heap_size) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; std::tie(input_tiles, input_tiles_map) = gen_tiles( input, tile_shape, true, torch::TensorOptions().device(torch::kCUDA, 0)); return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, weight_tiles_map, weight_shape, NULL, NULL, -1, + source_resistance, line_resistance, cuda_malloc_heap_size); +} + +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + int ADC_resolution, float ADC_overflow_rate, + int quant_method, int cuda_malloc_heap_size) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCUDA, 0)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, ADC_resolution, + ADC_overflow_rate, quant_method, -1, -1, cuda_malloc_heap_size); } at::Tensor tiled_inference(at::Tensor input, int input_shape[2], int tile_shape[2], at::Tensor weight_tiles, at::Tensor weight_tiles_map, int weight_shape[2], + float source_resistance, float line_resistance, int ADC_resolution, float ADC_overflow_rate, int quant_method, int cuda_malloc_heap_size) { at::Tensor input_tiles; @@ -29,11 +59,12 @@ at::Tensor tiled_inference(at::Tensor input, int input_shape[2], input, tile_shape, true, torch::TensorOptions().device(torch::kCUDA, 0)); 
return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, weight_tiles_map, weight_shape, ADC_resolution, - ADC_overflow_rate, quant_method, cuda_malloc_heap_size); + ADC_overflow_rate, quant_method, source_resistance, + line_resistance, cuda_malloc_heap_size); } void inference_bindings(py::module_ &m) { - // Binding without quantization support + // Binding without quantization support (transistor=True) m.def( "tiled_inference", [](at::Tensor input, std::tuple input_shape, @@ -56,7 +87,33 @@ void inference_bindings(py::module_ &m) { py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), py::arg("weight_tiles"), py::arg("weight_tiles_map"), py::arg("weight_shape"), py::arg("cuda_malloc_heap_size") = 50); - // Binding with quantization support + // Binding without quantization support (transistor=False) + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + float source_resistance, float line_resistance, + int cuda_malloc_heap_size) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference(input, input_shape_array, tile_shape_array, + weight_tiles, weight_tiles_map, + weight_shape_array, source_resistance, + line_resistance, cuda_malloc_heap_size); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("source_resistance"), + py::arg("line_resistance"), py::arg("cuda_malloc_heap_size") = 50); + // Binding with quantization support (transistor=True) m.def( "tiled_inference", [](at::Tensor input, std::tuple input_shape, @@ -83,4 +140,33 @@ void inference_bindings(py::module_ &m) { py::arg("weight_shape"), py::arg("ADC_resolution"), py::arg("ADC_overflow_rate"), py::arg("quant_method"), py::arg("cuda_malloc_heap_size") = 50); + // Binding with quantization support (transistor=False) + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + float source_resistance, float line_resistance, int ADC_resolution, + float ADC_overflow_rate, int quant_method, int cuda_malloc_heap_size) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference( + input, input_shape_array, tile_shape_array, weight_tiles, + weight_tiles_map, weight_shape_array, source_resistance, + line_resistance, ADC_resolution, ADC_overflow_rate, quant_method, + cuda_malloc_heap_size); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("source_resistance"), + 
py::arg("line_resistance"), py::arg("ADC_resolution"), + py::arg("ADC_overflow_rate"), py::arg("quant_method"), + py::arg("cuda_malloc_heap_size") = 50); } \ No newline at end of file diff --git a/memtorch/cu/solve_passive.cpp b/memtorch/cu/solve_passive.cpp new file mode 100644 index 00000000..e7deb505 --- /dev/null +++ b/memtorch/cu/solve_passive.cpp @@ -0,0 +1,24 @@ +#include +#include +#include + +#include + +#include + +#include + +#include "solve_passive_kernels.cuh" + +void solve_passive_bindings(py::module_ &m) { + m.def( + "solve_passive", + [&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL, + float R_source, float R_line, bool det_readout_currents) { + return solve_passive(conductance_matrix, V_WL, V_BL, R_source, R_line, + det_readout_currents); + }, + py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"), + py::arg("R_source"), py::arg("R_line"), + py::arg("det_readout_currents") = true); +} \ No newline at end of file diff --git a/memtorch/cu/solve_passive.cuh b/memtorch/cu/solve_passive.cuh new file mode 100644 index 00000000..3819a4eb --- /dev/null +++ b/memtorch/cu/solve_passive.cuh @@ -0,0 +1,229 @@ +__device__ double cumsum(int *p, int *c, int n) { + int nz = 0; + double nz2 = 0; + for (int i = 0; i < n; i++) { + p[i] = nz; + nz += c[i]; + nz2 += c[i]; + c[i] = p[i]; + } + p[n] = nz; + return nz2; +} + +__device__ void cs_compress(int nst, int m, int n, int *ist, int *jst, + double *ast, int *ist_compressed, + int *jst_compressed, double *ast_compressed) { + int p; + int *w = (int *)malloc(sizeof(int) * n); + for (int i = 0; i < n; i++) { + w[i] = 0; + } + for (int k = 0; k < nst; k++) { + w[jst[k]]++; + } + cumsum(jst_compressed, w, n); + for (int k = 0; k < nst; k++) { + p = w[jst[k]]++; + ist_compressed[p] = ist[k]; + ast_compressed[p] = ast[k]; + } + free(w); + return; +} + +__device__ void +construct_ABCD_E(Eigen::MatrixXf conductance_matrix, Eigen::VectorXf V_WL, + Eigen::VectorXf V_BL, float R_source, float R_line, + int *ABCD_matrix_indices_x, int *ABCD_matrix_indices_y, + double *ABCD_matrix_values, int *ABCD_matrix_compressed_rows, + int *ABCD_matrix_compressed_columns, + double *ABCD_matirx_compressed_values, double *E_matrix) { + int m = conductance_matrix.rows(); + int n = conductance_matrix.cols(); + // A, B, and E (partial) matrices + int nonzero_elements = 8 * m * n - 2 * m - 2 * n; + int index = 0; + for (int i = 0; i < m; i++) { + // A matrix + ABCD_matrix_indices_x[index] = i * n; + ABCD_matrix_indices_y[index] = i * n; + if (R_source == 0) { + ABCD_matrix_values[index] = + (double)conductance_matrix(i, 0) + 1.0 / (double)R_line; + } else if (R_line == 0) { + ABCD_matrix_values[index] = + (double)conductance_matrix(i, 0) + 1.0 / (double)R_source; + } else { + ABCD_matrix_values[index] = (double)conductance_matrix(i, 0) + + 1.0 / (double)R_source + 1.0 / (double)R_line; + } + index++; + ABCD_matrix_indices_x[index] = i * n + 1; + ABCD_matrix_indices_y[index] = i * n; + if (R_line == 0) { + ABCD_matrix_values[index] = 0; + index++; + ABCD_matrix_indices_x[index] = i * n; + ABCD_matrix_indices_y[index] = i * n + 1; + ABCD_matrix_values[index] = 0; + index++; + ABCD_matrix_indices_x[index] = i * n + (n - 1); + ABCD_matrix_indices_y[index] = i * n + (n - 1); + ABCD_matrix_values[index] = (double)conductance_matrix(i, n - 1); + } else { + ABCD_matrix_values[index] = -1.0 / (double)R_line; + index++; + ABCD_matrix_indices_x[index] = i * n; + ABCD_matrix_indices_y[index] = i * n + 1; + ABCD_matrix_values[index] = -1.0 / 
(double)R_line; + index++; + ABCD_matrix_indices_x[index] = i * n + (n - 1); + ABCD_matrix_indices_y[index] = i * n + (n - 1); + ABCD_matrix_values[index] = + (double)conductance_matrix(i, n - 1) + 1.0 / (double)R_line; + } + index++; + // B matrix + ABCD_matrix_indices_x[index] = i * n; + ABCD_matrix_indices_y[index] = i * n + (m * n); + ABCD_matrix_values[index] = (double)-conductance_matrix(i, 0); + index++; + ABCD_matrix_indices_x[index] = i * n + (n - 1); + ABCD_matrix_indices_y[index] = i * n + (n - 1) + (m * n); + ABCD_matrix_values[index] = (double)-conductance_matrix(i, n - 1); + index++; + // E matrix + if (R_source == 0) { + E_matrix[i * n] = (double)V_WL[i]; + } else { + E_matrix[i * n] = (double)V_WL[i] / (double)R_source; + } + for (int j = 1; j < n - 1; j++) { + // A matrix + ABCD_matrix_indices_x[index] = i * n + j; + ABCD_matrix_indices_y[index] = i * n + j; + if (R_line == 0) { + ABCD_matrix_values[index] = (double)conductance_matrix(i, j); + index++; + ABCD_matrix_indices_x[index] = i * n + j + 1; + ABCD_matrix_indices_y[index] = i * n + j; + ABCD_matrix_values[index] = 0; + index++; + ABCD_matrix_indices_x[index] = i * n + j; + ABCD_matrix_indices_y[index] = i * n + j + 1; + ABCD_matrix_values[index] = 0; + } else { + ABCD_matrix_values[index] = + (double)conductance_matrix(i, j) + 2.0 / (double)R_line; + index++; + ABCD_matrix_indices_x[index] = i * n + j + 1; + ABCD_matrix_indices_y[index] = i * n + j; + ABCD_matrix_values[index] = -1.0 / (double)R_line; + index++; + ABCD_matrix_indices_x[index] = i * n + j; + ABCD_matrix_indices_y[index] = i * n + j + 1; + ABCD_matrix_values[index] = -1.0 / (double)R_line; + } + index++; + // B matrix + ABCD_matrix_indices_x[index] = i * n + j; + ABCD_matrix_indices_y[index] = i * n + j + (m * n); + ABCD_matrix_values[index] = (double)-conductance_matrix(i, j); + index++; + } + } + // C, D, and E (partial) matrices + for (int j = 0; j < n; j++) { + // D matrix + ABCD_matrix_indices_x[index] = m * n + (j * m); + ABCD_matrix_indices_y[index] = m * n + j; + ABCD_matrix_values[index] = + -1.0 / (double)R_line - conductance_matrix(0, j); + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m); + ABCD_matrix_indices_y[index] = m * n + j + n; + if (R_line == 0) { + ABCD_matrix_values[index] = 0; + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + m - 1; + ABCD_matrix_indices_y[index] = m * n + (n * (m - 2)) + j; + ABCD_matrix_values[index] = 0; + } else { + ABCD_matrix_values[index] = 1.0 / (double)R_line; + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + m - 1; + ABCD_matrix_indices_y[index] = m * n + (n * (m - 2)) + j; + ABCD_matrix_values[index] = 1.0 / (double)R_line; + } + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + m - 1; + ABCD_matrix_indices_y[index] = m * n + (n * (m - 1)) + j; + if (R_source == 0) { + ABCD_matrix_values[index] = + -conductance_matrix(m - 1, j) - 1.0 / (double)R_line; + } else if (R_line == 0) { + ABCD_matrix_values[index] = + -1.0 / (double)R_source - conductance_matrix(m - 1, j); + } else { + ABCD_matrix_values[index] = -1.0 / (double)R_source - + conductance_matrix(m - 1, j) - + 1.0 / (double)R_line; + } + index++; + // C matrix + ABCD_matrix_indices_x[index] = j * m + (m * n); + ABCD_matrix_indices_y[index] = j; + ABCD_matrix_values[index] = (double)conductance_matrix(0, j); + index++; + ABCD_matrix_indices_x[index] = j * m + (m - 1) + (m * n); + ABCD_matrix_indices_y[index] = n * (m - 1) + j; + ABCD_matrix_values[index] = (double)conductance_matrix(m - 1, j); + 
index++; + // E matrix + if (R_source == 0) { + E_matrix[m * n + (j + 1) * m - 1] = -V_BL[j]; + } else { + E_matrix[m * n + (j + 1) * m - 1] = -V_BL[j] / R_source; + } + for (int i = 1; i < m - 1; i++) { + // D matrix + ABCD_matrix_indices_x[index] = m * n + (j * m) + i; + ABCD_matrix_indices_y[index] = m * n + (n * (i - 1)) + j; + if (R_line == 0) { + ABCD_matrix_values[index] = 0; + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + i; + ABCD_matrix_indices_y[index] = m * n + (n * (i + 1)) + j; + ABCD_matrix_values[index] = 0; + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + i; + ABCD_matrix_indices_y[index] = m * n + (n * i) + j; + ABCD_matrix_values[index] = (double)-conductance_matrix(i, j); + } else { + ABCD_matrix_values[index] = 1.0 / (double)R_line; + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + i; + ABCD_matrix_indices_y[index] = m * n + (n * (i + 1)) + j; + ABCD_matrix_values[index] = 1.0 / (double)R_line; + index++; + ABCD_matrix_indices_x[index] = m * n + (j * m) + i; + ABCD_matrix_indices_y[index] = m * n + (n * i) + j; + ABCD_matrix_values[index] = + (double)-conductance_matrix(i, j) - 2.0 / (double)R_line; + } + index++; + // C matrix + ABCD_matrix_indices_x[index] = j * m + i + (m * n); + ABCD_matrix_indices_y[index] = n * i + j; + ABCD_matrix_values[index] = (double)conductance_matrix(i, j); + index++; + } + } + V_WL.resize(0, 0); + V_BL.resize(0, 0); + cs_compress(nonzero_elements, 2 * n * m, 2 * n * m, ABCD_matrix_indices_x, + ABCD_matrix_indices_y, ABCD_matrix_values, + ABCD_matrix_compressed_rows, ABCD_matrix_compressed_columns, + ABCD_matirx_compressed_values); +} \ No newline at end of file diff --git a/memtorch/cu/solve_passive.h b/memtorch/cu/solve_passive.h new file mode 100644 index 00000000..9c9689de --- /dev/null +++ b/memtorch/cu/solve_passive.h @@ -0,0 +1 @@ +void solve_passive_bindings(py::module_ &m); \ No newline at end of file diff --git a/memtorch/cu/solve_passive_kernels.cu b/memtorch/cu/solve_passive_kernels.cu new file mode 100644 index 00000000..56177564 --- /dev/null +++ b/memtorch/cu/solve_passive_kernels.cu @@ -0,0 +1,284 @@ +#include "cuda_runtime.h" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "solve_passive.h" +#include "solve_sparse_linear.h" +#include "utils.cuh" + +class Triplet { +public: + __host__ __device__ Triplet() : m_row(0), m_col(0), m_value(0) {} + + __host__ __device__ Triplet(int i, int j, float v) + : m_row(i), m_col(j), m_value(v) {} + + __host__ __device__ const int &row() { return m_row; } + __host__ __device__ const int &col() { return m_col; } + __host__ __device__ const float &value() { return m_value; } + +protected: + int m_row, m_col; + float m_value; +}; + +typedef Triplet sparse_element; + +__global__ void gen_ABE_kernel( + torch::PackedTensorAccessor32 conductance_matrix_accessor, + float *V_WL_accessor, float *V_BL_accessor, int m, int n, float R_source, + float R_line, sparse_element *ABCD_matrix, float *E_matrix) { + int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < m; i++) + int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < n; j++) + if (i < m && j < n) { + int index = (i * n + j) * 5; + // A matrix + if (j == 0) { + // E matrix (partial) + if (E_matrix != NULL) { + if (R_source == 0) { + E_matrix[i * n] = V_WL_accessor[i]; + } else { + E_matrix[i * n] = V_WL_accessor[i] / R_source; + } + } + if (R_source == 0) { + ABCD_matrix[index] = sparse_element( + i * n, i * n, 
conductance_matrix_accessor[i][0] + 1.0f / R_line); + } else if (R_line == 0) { + ABCD_matrix[index] = sparse_element( + i * n, i * n, conductance_matrix_accessor[i][0] + 1.0f / R_source); + } else { + ABCD_matrix[index] = + sparse_element(i * n, i * n, + conductance_matrix_accessor[i][0] + 1.0f / R_source + + 1.0f / R_line); + } + } else { + ABCD_matrix[index] = sparse_element(0, 0, 0.0f); + } + index++; + if (R_line == 0) { + ABCD_matrix[index] = sparse_element(i * n + j, i * n + j, + conductance_matrix_accessor[i][j]); + } else { + ABCD_matrix[index] = + sparse_element(i * n + j, i * n + j, + conductance_matrix_accessor[i][j] + 2.0f / R_line); + } + index++; + if (j < n - 1) { + if (R_line == 0) { + ABCD_matrix[index] = sparse_element(i * n + j + 1, i * n + j, 0); + index++; + ABCD_matrix[index] = sparse_element(i * n + j, i * n + j + 1, 0); + } else { + ABCD_matrix[index] = + sparse_element(i * n + j + 1, i * n + j, -1.0f / R_line); + index++; + ABCD_matrix[index] = + sparse_element(i * n + j, i * n + j + 1, -1.0f / R_line); + } + } else { + if (R_line == 0) { + ABCD_matrix[index] = sparse_element(i * n + j, i * n + j, + conductance_matrix_accessor[i][j]); + } else { + ABCD_matrix[index] = + sparse_element(i * n + j, i * n + j, + conductance_matrix_accessor[i][j] + 1.0 / R_line); + } + index++; + ABCD_matrix[index] = sparse_element(0, 0, 0.0f); + } + index++; + // B matrix + ABCD_matrix[index] = sparse_element(i * n + j, i * n + j + (m * n), + -conductance_matrix_accessor[i][j]); + } +} + +__global__ void gen_CDE_kernel( + torch::PackedTensorAccessor32 conductance_matrix_accessor, + float *V_WL_accessor, float *V_BL_accessor, int m, int n, float R_source, + float R_line, sparse_element *ABCD_matrix, float *E_matrix) { + int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < n; j++) + int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < m; i++) + if (j < n && i < m) { + int index = (5 * m * n) + ((j * m + i) * 4); + // D matrix + if (i == 0) { + // E matrix (partial) + if (E_matrix != NULL) { + if (R_source == 0) { + E_matrix[m * n + (j + 1) * m - 1] = -V_BL_accessor[j]; + } else { + E_matrix[m * n + (j + 1) * m - 1] = -V_BL_accessor[j] / R_source; + } + } + if (R_line == 0) { + ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j, + -conductance_matrix_accessor[0][j]); + index++; + ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j + n, 0); + } else { + ABCD_matrix[index] = + sparse_element(m * n + (j * m), m * n + j, + -1.0f / R_line - conductance_matrix_accessor[0][j]); + index++; + ABCD_matrix[index] = + sparse_element(m * n + (j * m), m * n + j + n, 1.0f / R_line); + } + index++; + ABCD_matrix[index] = sparse_element(0, 0, 0.0f); + } else if (i < m - 1) { + if (R_line == 0) { + ABCD_matrix[index] = + sparse_element(m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 0); + index++; + ABCD_matrix[index] = + sparse_element(m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 0); + index++; + ABCD_matrix[index] = + sparse_element(m * n + (j * m) + i, m * n + (n * i) + j, + -conductance_matrix_accessor[i][j]); + } else { + ABCD_matrix[index] = sparse_element( + m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 1.0f / R_line); + index++; + ABCD_matrix[index] = sparse_element( + m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 1.0f / R_line); + index++; + ABCD_matrix[index] = + sparse_element(m * n + (j * m) + i, m * n + (n * i) + j, + -conductance_matrix_accessor[i][j] - 2.0f / R_line); + } + } else { + if (R_line == 0) { + ABCD_matrix[index] = 
sparse_element(m * n + (j * m) + m - 1, + m * n + (n * (m - 2)) + j, 0); + } else { + ABCD_matrix[index] = sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 2)) + j, 1 / R_line); + } + index++; + if (R_source == 0) { + ABCD_matrix[index] = sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j, + -conductance_matrix_accessor[m - 1][j] - 1.0f / R_line); + } else if (R_line == 0) { + ABCD_matrix[index] = sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j, + -1.0f / R_source - conductance_matrix_accessor[m - 1][j]); + } else { + ABCD_matrix[index] = sparse_element( + m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j, + -1.0f / R_source - conductance_matrix_accessor[m - 1][j] - + 1.0f / R_line); + } + index++; + ABCD_matrix[index] = sparse_element(0, 0, 0.0f); + } + index++; + // C matrix + ABCD_matrix[index] = sparse_element(j * m + i + (m * n), n * i + j, + conductance_matrix_accessor[i][j]); + } +} + +__global__ void +construct_V_applied(torch::PackedTensorAccessor32 V_applied_accessor, + float *V_accessor, int m, int n) { + int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < m; i++) + int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < n; j++) + if (i < m && j < n) { + V_applied_accessor[i][j] = + V_accessor[n * i + j] - V_accessor[m * n + n * i + j]; + } +} + +at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL, + at::Tensor V_BL, float R_source, float R_line, + bool det_readout_currents) { + assert(at::cuda::is_available()); + conductance_matrix = conductance_matrix.to(torch::Device("cuda:0")); + V_WL = V_WL.to(torch::Device("cuda:0")); + V_BL = V_BL.to(torch::Device("cuda:0")); + int m = conductance_matrix.sizes()[0]; + int n = conductance_matrix.sizes()[1]; + torch::PackedTensorAccessor32 conductance_matrix_accessor = + conductance_matrix.packed_accessor32(); + float *V_WL_accessor = V_WL.data_ptr(); + float *V_BL_accessor = V_BL.data_ptr(); + int non_zero_elements = + (5 * m * n) + + (4 * m * n); // Uncompressed (with padding for CUDA execution). + // When compressed, contains 8 * m * n - 2 * m - 2 * n unique values. 
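+  // The gen_ABE_kernel and gen_CDE_kernel launches below assemble, in parallel,
+  // the triplets of the sparse 2mn x 2mn nodal-analysis system [A, B; C, D] v = E
+  // for a passive m x n crossbar (A/B rows describe word-line nodes, C/D rows
+  // describe bit-line nodes). The triplets and right-hand side are copied back to
+  // the host and assembled into an Eigen sparse system; construct_V_applied then
+  // maps the recovered word/bit-line node voltages onto an m x n tensor of
+  // per-device applied voltages.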
+ sparse_element *ABCD_matrix; + sparse_element *ABCD_matrix_host = + (sparse_element *)malloc(sizeof(sparse_element) * non_zero_elements); + cudaMalloc(&ABCD_matrix, sizeof(sparse_element) * non_zero_elements); + float *E_matrix; + cudaMalloc(&E_matrix, sizeof(float) * 2 * m * n); + cudaMemset(E_matrix, 0, sizeof(float) * 2 * m * n); + float *E_matrix_host = (float *)malloc(sizeof(float) * 2 * m * n); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int max_threads = prop.maxThreadsDim[0]; + dim3 grid; + dim3 block; + if (m * n > max_threads) { + int n_grid = ceil_int_div(m * n, max_threads); + grid = dim3(n_grid, n_grid, 1); + block = dim3(ceil_int_div(m, n_grid), ceil_int_div(n, n_grid), 1); + } else { + grid = dim3(1, 1, 1); + block = dim3(m, n, 1); + } + gen_ABE_kernel<<>>(conductance_matrix_accessor, V_WL_accessor, + V_BL_accessor, m, n, R_source, R_line, + ABCD_matrix, E_matrix); + gen_CDE_kernel<<>>(conductance_matrix_accessor, V_WL_accessor, + V_BL_accessor, m, n, R_source, R_line, + ABCD_matrix, E_matrix); + cudaSafeCall(cudaDeviceSynchronize()); + Eigen::SparseMatrix ABCD(2 * m * n, 2 * m * n); + cudaMemcpy(ABCD_matrix_host, ABCD_matrix, + sizeof(sparse_element) * non_zero_elements, + cudaMemcpyDeviceToHost); + ABCD.setFromTriplets(&ABCD_matrix_host[0], + &ABCD_matrix_host[non_zero_elements]); + ABCD.makeCompressed(); + cudaMemcpy(E_matrix_host, E_matrix, sizeof(float) * 2 * m * n, + cudaMemcpyDeviceToHost); + Eigen::Map V(E_matrix_host, 2 * m * n); + at::Tensor V_applied_tensor = + at::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA, 0)); + torch::PackedTensorAccessor32 V_applied_accessor = + V_applied_tensor.packed_accessor32(); + float *V_accessor; + cudaMalloc(&V_accessor, sizeof(float) * V.size()); + cudaMemcpy(V_accessor, V.data(), sizeof(float) * V.size(), + cudaMemcpyHostToDevice); + construct_V_applied<<>>(V_applied_accessor, V_accessor, m, n); + cudaSafeCall(cudaDeviceSynchronize()); + cudaSafeCall(cudaFree(ABCD_matrix)); + cudaSafeCall(cudaFree(E_matrix)); + cudaSafeCall(cudaFree(V_accessor)); + cudaStreamSynchronize(at::cuda::getCurrentCUDAStream()); + if (!det_readout_currents) { + return V_applied_tensor; + } else { + return at::sum(at::mul(V_applied_tensor, conductance_matrix), 0); + } +} \ No newline at end of file diff --git a/memtorch/cu/solve_passive_kernels.cuh b/memtorch/cu/solve_passive_kernels.cuh new file mode 100644 index 00000000..eb829a25 --- /dev/null +++ b/memtorch/cu/solve_passive_kernels.cuh @@ -0,0 +1,3 @@ +at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL, + at::Tensor V_BL, float R_source, float R_line, + bool det_readout_currents); \ No newline at end of file diff --git a/memtorch/cu/solve_sparse_linear.cpp b/memtorch/cu/solve_sparse_linear.cpp new file mode 100644 index 00000000..cd421963 --- /dev/null +++ b/memtorch/cu/solve_sparse_linear.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include + +void solve_sparse_linear(Eigen::SparseMatrix A, double *B_values, + int n) { + Eigen::SparseQR, Eigen::COLAMDOrdering> QR( + A); + QR.analyzePattern(A); + QR.factorize(A); + Eigen::Map B(B_values, n); + Eigen::VectorXd X = QR.solve(B); + memcpy(B_values, X.data(), sizeof(double) * n); +} + +void solve_sparse_linear(Eigen::SparseMatrix A, float *B_values, int n) { + Eigen::SparseQR, Eigen::COLAMDOrdering> QR(A); + QR.analyzePattern(A); + QR.factorize(A); + Eigen::Map B(B_values, n); + Eigen::VectorXf X = QR.solve(B); + memcpy(B_values, X.data(), sizeof(float) * n); +} \ No newline at end of file diff 
--git a/memtorch/cu/solve_sparse_linear.h b/memtorch/cu/solve_sparse_linear.h new file mode 100644 index 00000000..49a59725 --- /dev/null +++ b/memtorch/cu/solve_sparse_linear.h @@ -0,0 +1,6 @@ +#include +#include +#include + +void solve_sparse_linear(Eigen::SparseMatrix A, double *B_values, int n); +void solve_sparse_linear(Eigen::SparseMatrix A, float *B_values, int n); \ No newline at end of file diff --git a/memtorch/cu/tile_matmul.cpp b/memtorch/cu/tile_matmul.cpp index 6b6cf584..b39f85dd 100644 --- a/memtorch/cu/tile_matmul.cpp +++ b/memtorch/cu/tile_matmul.cpp @@ -8,11 +8,13 @@ #include "tile_matmul_kernels.cuh" void tile_matmul_bindings(py::module_ &m) { + // Bindings without quantization support m.def( "tile_matmul", [&](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, std::tuple mat_a_shape, at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, + float source_resistance, float line_resistance, int cuda_malloc_heap_size) { assert((std::tuple_size(mat_a_shape) == 2) && (std::tuple_size(mat_b_shape) == 2)); @@ -22,18 +24,73 @@ void tile_matmul_bindings(py::module_ &m) { (int)std::get<1>(mat_b_shape)}; return tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape_array, mat_b_tiles, mat_b_tiles_map, mat_b_shape_array, -1, - -1, -1, cuda_malloc_heap_size); + -1, -1, source_resistance, line_resistance, + cuda_malloc_heap_size); + }, + py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), + py::arg("mat_a_shape"), py::arg("mat_b_tiles"), + py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), + py::arg("source_resistance") = -1, py::arg("line_resistance") = -1, + py::arg("cuda_malloc_heap_size") = 50); + // Binding without quantization support (transistor=False) + m.def( + "tile_matmul", + [&](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + std::tuple mat_a_shape, at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, + float source_resistance, float line_resistance, + int cuda_malloc_heap_size) { + assert((std::tuple_size(mat_a_shape) == 2) && + (std::tuple_size(mat_b_shape) == 2)); + int mat_a_shape_array[2] = {(int)std::get<0>(mat_a_shape), + (int)std::get<1>(mat_a_shape)}; + int mat_b_shape_array[2] = {(int)std::get<0>(mat_b_shape), + (int)std::get<1>(mat_b_shape)}; + return tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape_array, + mat_b_tiles, mat_b_tiles_map, mat_b_shape_array, -1, + -1, -1, source_resistance, line_resistance, + cuda_malloc_heap_size); }, py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), py::arg("mat_a_shape"), py::arg("mat_b_tiles"), py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), + py::arg("source_resistance"), py::arg("line_resistance"), py::arg("cuda_malloc_heap_size") = 50); + // Binding with quantization support (transistor=True) + m.def( + "tile_matmul", + [&](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + std::tuple mat_a_shape, at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, + int ADC_resolution, float overflow_rate, int quant_method, + float source_resistance, float line_resistance, + int cuda_malloc_heap_size) { + assert((std::tuple_size(mat_a_shape) == 2) && + (std::tuple_size(mat_b_shape) == 2)); + int mat_a_shape_array[2] = {(int)std::get<0>(mat_a_shape), + (int)std::get<1>(mat_a_shape)}; + int mat_b_shape_array[2] = {(int)std::get<0>(mat_b_shape), + (int)std::get<1>(mat_b_shape)}; + return tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape_array, + mat_b_tiles, mat_b_tiles_map, mat_b_shape_array, + ADC_resolution, overflow_rate, quant_method, + 
source_resistance, line_resistance, + cuda_malloc_heap_size); + }, + py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), + py::arg("mat_a_shape"), py::arg("mat_b_tiles"), + py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), + py::arg("ADC_resolution"), py::arg("ADC_overflow_rate"), + py::arg("quant_method"), py::arg("source_resistance") = -1, + py::arg("line_resistance") = -1, py::arg("cuda_malloc_heap_size") = 50); + // Binding with quantization support (transistor=False) m.def( "tile_matmul", [&](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, std::tuple mat_a_shape, at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, std::tuple mat_b_shape, int ADC_resolution, float overflow_rate, int quant_method, + float source_resistance, float line_resistance, int cuda_malloc_heap_size) { assert((std::tuple_size(mat_a_shape) == 2) && (std::tuple_size(mat_b_shape) == 2)); @@ -44,11 +101,13 @@ void tile_matmul_bindings(py::module_ &m) { return tile_matmul(mat_a_tiles, mat_a_tiles_map, mat_a_shape_array, mat_b_tiles, mat_b_tiles_map, mat_b_shape_array, ADC_resolution, overflow_rate, quant_method, + source_resistance, line_resistance, cuda_malloc_heap_size); }, py::arg("mat_a_tiles"), py::arg("mat_a_tiles_map"), py::arg("mat_a_shape"), py::arg("mat_b_tiles"), py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), py::arg("ADC_resolution"), py::arg("ADC_overflow_rate"), - py::arg("quant_method"), py::arg("cuda_malloc_heap_size") = 50); + py::arg("quant_method"), py::arg("source_resistance"), + py::arg("line_resistance"), py::arg("cuda_malloc_heap_size") = 50); } \ No newline at end of file diff --git a/memtorch/cu/tile_matmul_kernels.cu b/memtorch/cu/tile_matmul_kernels.cu index 99bff258..a85e23d9 100644 --- a/memtorch/cu/tile_matmul_kernels.cu +++ b/memtorch/cu/tile_matmul_kernels.cu @@ -1,4 +1,5 @@ #include "cuda_runtime.h" +#include "utils.cuh" #include #include #include @@ -7,10 +8,12 @@ #include #include - -#include "utils.cuh" +#include +#include #include "quantize.cuh" +#include "solve_passive.cuh" +#include "solve_sparse_linear.h" __global__ void tile_matmul_kernel( float *mat_a_tiles_accessor, @@ -36,7 +39,6 @@ __global__ void tile_matmul_kernel( mat_b_tiles_shape[1], mat_b_tiles_shape[2], Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2])); Eigen::VectorXf partial_sum = (tile_a * tile_b).transpose(); -#pragma omp parallel for for (int ii = 0; ii < partial_sum.size(); ii++) { result[transform_2d_index(i, j * mat_b_tiles_shape[2] + ii, mat_b_shape_back)] += partial_sum[ii]; @@ -45,6 +47,85 @@ __global__ void tile_matmul_kernel( } } +__global__ void tile_matmul_kernel_A( + float *mat_a_tiles_accessor, + torch::PackedTensorAccessor32 mat_a_tiles_map_accessor, + int64_t *mat_a_tiles_shape, float *mat_b_tiles_accessor, + torch::PackedTensorAccessor32 mat_b_tiles_map_accessor, + int64_t *mat_b_tiles_shape, int mat_b_shape_back, + int *ABCD_matrix_indices_x, int *ABCD_matrix_indices_y, + double *ABCD_matrix_values, int *ABCD_matrix_compressed_rows, + int *ABCD_matrix_compressed_columns, double *ABCD_matrix_compressed_values, + double *E_matrix, float source_resistance, float line_resistance, + int limit_i, int limit_j, int limit_k) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + int j = threadIdx.y + blockIdx.y * blockDim.y; + int k = threadIdx.z + blockIdx.z * blockDim.z; + if (i < limit_i && j < limit_j && k < limit_k) { + Eigen::Map tile_a( + &mat_a_tiles_accessor[transform_3d_index(mat_a_tiles_map_accessor[k], i, + 0, mat_a_tiles_shape[1], + mat_a_tiles_shape[2])], + 
mat_a_tiles_shape[1]); + Eigen::Map> + tile_b(&mat_b_tiles_accessor[transform_3d_index( + mat_b_tiles_map_accessor[k][j], 0, 0, mat_b_tiles_shape[1], + mat_b_tiles_shape[2])], + mat_b_tiles_shape[1], mat_b_tiles_shape[2], + Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2])); + int m = (int)mat_b_tiles_shape[1]; + int n = (int)mat_b_tiles_shape[2]; + int nonzero_elements = 8 * m * n - 2 * m - 2 * n; + int kernel_index = transform_3d_index(i, j, k, limit_j, limit_k); + construct_ABCD_E( + tile_b, tile_a, Eigen::VectorXf::Zero(n), source_resistance, + line_resistance, + &ABCD_matrix_indices_x[kernel_index * nonzero_elements], + &ABCD_matrix_indices_y[kernel_index * nonzero_elements], + &ABCD_matrix_values[kernel_index * nonzero_elements], + &ABCD_matrix_compressed_rows[kernel_index * nonzero_elements], + &ABCD_matrix_compressed_columns[kernel_index * (2 * m * n)], + &ABCD_matrix_compressed_values[kernel_index * nonzero_elements], + &E_matrix[kernel_index * (2 * m * n)]); + } +} + +__global__ void tile_matmul_kernel_B( + double *E_matrix, float *mat_b_tiles_accessor, + torch::PackedTensorAccessor32 mat_b_tiles_map_accessor, + int64_t *mat_b_tiles_shape, int mat_b_shape_back, int m, int n, int limit_i, + int limit_j, int limit_k, float *result) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + int j = threadIdx.y + blockIdx.y * blockDim.y; + int k = threadIdx.z + blockIdx.z * blockDim.z; + if (i < limit_i && j < limit_j && k < limit_k) { + int kernel_index = transform_3d_index(i, j, k, limit_j, limit_k); + Eigen::Map> + tile_b(&mat_b_tiles_accessor[transform_3d_index( + mat_b_tiles_map_accessor[k][j], 0, 0, mat_b_tiles_shape[1], + mat_b_tiles_shape[2])], + mat_b_tiles_shape[1], mat_b_tiles_shape[2], + Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2])); + Eigen::MatrixXf I_applied_tensor = Eigen::MatrixXf::Zero(m, n); + for (int ii = 0; ii < m; ii++) { + for (int jj = 0; jj < n; jj++) { + I_applied_tensor(ii, jj) = + ((float)E_matrix[kernel_index * (2 * m * n) + n * ii + jj] - + (float) + E_matrix[kernel_index * (2 * m * n) + m * n + n * ii + jj]) * + tile_b(ii, jj); + } + } + Eigen::VectorXf I_tensor = I_applied_tensor.colwise().sum(); + for (int ii = 0; ii < n; ii++) { + result[transform_2d_index(i, j * mat_b_tiles_shape[2] + ii, + mat_b_shape_back)] += I_tensor[ii]; + } + } +} + __global__ void tile_matmul_kernel( float *mat_a_tiles_accessor, torch::PackedTensorAccessor32 mat_a_tiles_map_accessor, @@ -82,11 +163,49 @@ __global__ void tile_matmul_kernel( } } +__global__ void tile_matmul_kernel_B( + double *E_matrix, float *mat_b_tiles_accessor, + torch::PackedTensorAccessor32 mat_b_tiles_map_accessor, + int64_t *mat_b_tiles_shape, int mat_b_shape_back, int ADC_resolution, + float overflow_rate, int quant_method, int m, int n, int limit_i, + int limit_j, int limit_k, float *result) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + int j = threadIdx.y + blockIdx.y * blockDim.y; + int k = threadIdx.z + blockIdx.z * blockDim.z; + if (i < limit_i && j < limit_j && k < limit_k) { + int kernel_index = transform_3d_index(i, j, k, limit_j, limit_k); + Eigen::Map> + tile_b(&mat_b_tiles_accessor[transform_3d_index( + mat_b_tiles_map_accessor[k][j], 0, 0, mat_b_tiles_shape[1], + mat_b_tiles_shape[2])], + mat_b_tiles_shape[1], mat_b_tiles_shape[2], + Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2])); + Eigen::MatrixXf I_applied_tensor = Eigen::MatrixXf::Zero(m, n); + for (int ii = 0; ii < m; ii++) { + for (int jj = 0; jj < n; jj++) { + I_applied_tensor(ii, jj) = + 
((float)E_matrix[kernel_index * (2 * m * n) + n * ii + jj] - + (float) + E_matrix[kernel_index * (2 * m * n) + m * n + n * ii + jj]) * + tile_b(ii, jj); + } + } + Eigen::VectorXf I_tensor = I_applied_tensor.colwise().sum(); + I_tensor = quantize(I_tensor, ADC_resolution, overflow_rate, quant_method); + for (int ii = 0; ii < n; ii++) { + result[transform_2d_index(i, j * mat_b_tiles_shape[2] + ii, + mat_b_shape_back)] += I_tensor[ii]; + } + } +} + at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, int mat_a_shape[2], at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, int mat_b_shape[2], int ADC_resolution, float overflow_rate, - int quant_method, int cuda_malloc_heap_size) { + int quant_method, float source_resistance, + float line_resistance, int cuda_malloc_heap_size) { assert(at::cuda::is_available()); mat_a_tiles = mat_a_tiles.to(torch::Device("cuda:0")); mat_a_tiles_map = mat_a_tiles_map.to(torch::Device("cuda:0")); @@ -121,12 +240,25 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, at::Tensor result = at::zeros({mat_a_shape[0], mat_b_shape[1]}, torch::device(torch::kCUDA)); cudaDeviceSetLimit(cudaLimitMallocHeapSize, - 1024 * 1024 * cuda_malloc_heap_size); + size_t(1024) * size_t(1024) * + size_t(cuda_malloc_heap_size)); + dim3 grid; + dim3 block; if (max_threads_dim[0] >= limit_i && max_threads_dim[1] >= limit_j && max_threads_dim[2] >= limit_k) { // If multiple blocks are not required - dim3 grid(limit_i, limit_j, limit_k); - dim3 block(1, 1, 1); + grid = {(unsigned int)limit_i, (unsigned int)limit_j, + (unsigned int)limit_k}; + block = {1, 1, 1}; + } else { + // If multiple blocks are required + grid = {(unsigned int)max_threads_dim[0], (unsigned int)max_threads_dim[1], + (unsigned int)max_threads_dim[2]}; + block = {(unsigned int)ceil_int_div(limit_i, max_threads_dim[0]), + (unsigned int)ceil_int_div(limit_j, max_threads_dim[1]), + (unsigned int)ceil_int_div(limit_k, max_threads_dim[2])}; + } + if (line_resistance == -1) { if (ADC_resolution == -1) { tile_matmul_kernel<<>>( mat_a_tiles_accessor, mat_a_tiles_map_accessor, mat_a_tiles_shape, @@ -140,23 +272,93 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, limit_j, limit_k, result.data_ptr()); } } else { - // If multiple blocks are required - dim3 grid(max_threads_dim[0], max_threads_dim[1], max_threads_dim[2]); - dim3 block(ceil_int_div(limit_i, max_threads_dim[0]), - ceil_int_div(limit_j, max_threads_dim[1]), - ceil_int_div(limit_k, max_threads_dim[2])); + int m = mat_b_tiles_shape_host[1]; + int n = mat_b_tiles_shape_host[2]; + int non_zero_elements = 8 * m * n - 2 * m - 2 * n; + int n_kernels = grid.x * block.x * grid.y * block.y * grid.z * block.z; + int *ABCD_matrix_indices_x; + int *ABCD_matrix_indices_y; + double *ABCD_matrix_values; + int *ABCD_matrix_compressed_columns; + int *ABCD_matrix_compressed_rows; + double *ABCD_matrix_compressed_values; + double *E_matrix; + cudaSafeCall(cudaMalloc(&ABCD_matrix_indices_x, + sizeof(int) * non_zero_elements * n_kernels)); + cudaSafeCall(cudaMalloc(&ABCD_matrix_indices_y, + sizeof(int) * non_zero_elements * n_kernels)); + cudaSafeCall(cudaMalloc(&ABCD_matrix_values, + sizeof(double) * non_zero_elements * n_kernels)); + cudaSafeCall(cudaMalloc(&ABCD_matrix_compressed_columns, + sizeof(int) * (2 * n * m) * n_kernels)); + cudaSafeCall(cudaMalloc(&ABCD_matrix_compressed_rows, + sizeof(int) * non_zero_elements * n_kernels)); + cudaSafeCall(cudaMalloc(&ABCD_matrix_compressed_values, + sizeof(double) * 
non_zero_elements * n_kernels)); + cudaSafeCall( + cudaMalloc(&E_matrix, sizeof(double) * (2 * m * n) * n_kernels)); + tile_matmul_kernel_A<<>>( + mat_a_tiles_accessor, mat_a_tiles_map_accessor, mat_a_tiles_shape, + mat_b_tiles_accessor, mat_b_tiles_map_accessor, mat_b_tiles_shape, + mat_b_shape[1], ABCD_matrix_indices_x, ABCD_matrix_indices_y, + ABCD_matrix_values, ABCD_matrix_compressed_rows, + ABCD_matrix_compressed_columns, ABCD_matrix_compressed_values, E_matrix, + source_resistance, line_resistance, limit_i, limit_j, limit_k); + cudaSafeCall(cudaDeviceSynchronize()); + cudaSafeCall(cudaFree(ABCD_matrix_indices_x)); + cudaSafeCall(cudaFree(ABCD_matrix_indices_y)); + cudaSafeCall(cudaFree(ABCD_matrix_values)); + int *ABCD_matrix_compressed_rows_host = + (int *)malloc(sizeof(int) * non_zero_elements); + int *ABCD_matrix_compressed_columns_host = + (int *)malloc(sizeof(int) * (2 * m * n)); + double *ABCD_matirx_compressed_values_host = + (double *)malloc(sizeof(double) * non_zero_elements); + double *E_matrix_host = + (double *)malloc(sizeof(double) * (2 * m * n) * n_kernels); + cudaSafeCall(cudaMemcpy(E_matrix_host, E_matrix, + sizeof(double) * (2 * m * n) * n_kernels, + cudaMemcpyDeviceToHost)); +#pragma omp parallel for + for (int i = 0; i < n_kernels; i++) { + cudaSafeCall( + cudaMemcpy(ABCD_matrix_compressed_rows_host, + &ABCD_matrix_compressed_rows[i * non_zero_elements], + sizeof(int) * non_zero_elements, cudaMemcpyDeviceToHost)); + cudaSafeCall(cudaMemcpy(ABCD_matrix_compressed_columns_host, + &ABCD_matrix_compressed_columns[i * (2 * n * m)], + sizeof(int) * (2 * m * n), + cudaMemcpyDeviceToHost)); + cudaSafeCall(cudaMemcpy( + ABCD_matirx_compressed_values_host, + &ABCD_matrix_compressed_values[i * non_zero_elements], + sizeof(double) * non_zero_elements, cudaMemcpyDeviceToHost)); + Eigen::Map> A( + (2 * m * n), (2 * m * n), non_zero_elements, + ABCD_matrix_compressed_columns_host, ABCD_matrix_compressed_rows_host, + ABCD_matirx_compressed_values_host); + solve_sparse_linear(A, &E_matrix_host[i * (2 * n * m)], 2 * m * n); + } + free(ABCD_matrix_compressed_rows_host); + free(ABCD_matrix_compressed_columns_host); + free(ABCD_matirx_compressed_values_host); + cudaSafeCall(cudaMemcpy(E_matrix, E_matrix_host, + sizeof(double) * (2 * n * m) * n_kernels, + cudaMemcpyHostToDevice)); + free(E_matrix_host); if (ADC_resolution == -1) { - tile_matmul_kernel<<>>( - mat_a_tiles_accessor, mat_a_tiles_map_accessor, mat_a_tiles_shape, - mat_b_tiles_accessor, mat_b_tiles_map_accessor, mat_b_tiles_shape, - mat_b_shape[1], limit_i, limit_j, limit_k, result.data_ptr()); + tile_matmul_kernel_B<<>>( + E_matrix, mat_b_tiles_accessor, mat_b_tiles_map_accessor, + mat_b_tiles_shape, mat_b_shape[1], m, n, limit_i, limit_j, limit_k, + result.data_ptr()); } else { - tile_matmul_kernel<<>>( - mat_a_tiles_accessor, mat_a_tiles_map_accessor, mat_a_tiles_shape, - mat_b_tiles_accessor, mat_b_tiles_map_accessor, mat_b_tiles_shape, - mat_b_shape[1], ADC_resolution, overflow_rate, quant_method, limit_i, - limit_j, limit_k, result.data_ptr()); + tile_matmul_kernel_B<<>>( + E_matrix, mat_b_tiles_accessor, mat_b_tiles_map_accessor, + mat_b_tiles_shape, mat_b_shape[1], ADC_resolution, overflow_rate, + quant_method, m, n, limit_i, limit_j, limit_k, + result.data_ptr()); } + cudaSafeCall(cudaFree(E_matrix)); } cudaSafeCall(cudaDeviceSynchronize()); cudaSafeCall(cudaFree(mat_a_tiles_shape)); diff --git a/memtorch/cu/tile_matmul_kernels.cuh b/memtorch/cu/tile_matmul_kernels.cuh index d434d172..ef90d937 100644 --- 
a/memtorch/cu/tile_matmul_kernels.cuh +++ b/memtorch/cu/tile_matmul_kernels.cuh @@ -2,4 +2,5 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, int mat_a_shape[2], at::Tensor mat_b_tiles, at::Tensor mat_b_tiles_map, int mat_b_shape[2], int ADC_resolution, float overflow_rate, - int quant_method, int cuda_malloc_heap_size); \ No newline at end of file + int quant_method, float source_resistance, + float line_resistance, int cuda_malloc_heap_size); \ No newline at end of file diff --git a/memtorch/cu/utils.cuh b/memtorch/cu/utils.cuh index 0ee5cd77..82b6d17b 100644 --- a/memtorch/cu/utils.cuh +++ b/memtorch/cu/utils.cuh @@ -53,12 +53,13 @@ template __device__ void sort_(T *tensor, int tensor_numel) { } } -__device__ int transform_2d_index(int x, int y, int len_y) { +inline __device__ int transform_2d_index(int x, int y, int len_y) { return x * len_y + y; } -__device__ int transform_3d_index(int x, int y, int z, int len_y, int len_z) { +inline __device__ int transform_3d_index(int x, int y, int z, int len_y, + int len_z) { return x * len_y * len_z + y * len_z + z; } -int ceil_int_div(int a, int b) { return (a + b - 1) / b; } \ No newline at end of file +inline int ceil_int_div(int a, int b) { return (a + b - 1) / b; } \ No newline at end of file diff --git a/memtorch/mn/Conv1d.py b/memtorch/mn/Conv1d.py index ae0474c3..a633c77e 100644 --- a/memtorch/mn/Conv1d.py +++ b/memtorch/mn/Conv1d.py @@ -1,4 +1,5 @@ import math +import warnings import numpy as np import torch @@ -43,6 +44,10 @@ class Conv1d(nn.Conv1d): Scaling routine to use in order to scale batch inputs. scaling_routine_params : **kwargs Scaling routine keyword arguments. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. ADC_overflow_rate : float @@ -70,6 +75,8 @@ def __init__( max_input_voltage=None, scaling_routine=naive_scale, scaling_routine_params={}, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, @@ -82,13 +89,29 @@ def __init__( convolutional_layer, nn.Conv1d ), "convolutional_layer is not an instance of nn.Conv1d." self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") + self.transistor = transistor self.scheme = scheme self.tile_shape = tile_shape self.max_input_voltage = max_input_voltage self.scaling_routine = scaling_routine self.scaling_routine_params = scaling_routine_params + self.source_resistance = source_resistance + self.line_resistance = line_resistance self.ADC_resolution = ADC_resolution self.ADC_overflow_rate = ADC_overflow_rate + if "cpu" not in memtorch.__version__: + self.cuda_malloc_heap_size = 50 + else: + self.cuda_malloc_heap_size = None + + if not transistor: + assert ( + source_resistance is not None and source_resistance >= 0.0 + ), "Source resistance is invalid." + assert ( + line_resistance is not None and line_resistance >= 0.0 + ), "Line resistance is invalid." 
+ if quant_method in memtorch.bh.Quantize.quant_methods: self.quant_method = quant_method else: @@ -184,8 +207,10 @@ def forward(self, input): .permute(1, 0, 2) .reshape(-1, self.in_channels * self.kernel_size[0]) ) - unfolded_batch_input_shape = unfolded_batch_input.shape if hasattr(self, "non_linear"): + warnings.warn( + "Non-liner modeling does not currently account for source and line resistances." + ) if self.tile_shape is not None: tiles_map = self.crossbars[0].tiles_map crossbar_shape = ( @@ -223,15 +248,33 @@ def forward(self, input): ) else: if self.tile_shape is not None: - out_ = tiled_inference(unfolded_batch_input, self).T - else: - out_ = torch.matmul( - unfolded_batch_input, - self.crossbar_operation( - self.crossbars, - lambda crossbar: crossbar.conductance_matrix, - ), + out_ = tiled_inference( + unfolded_batch_input, self, transistor=self.transistor ).T + else: + devices = self.crossbar_operation( + self.crossbars, + lambda crossbar: crossbar.conductance_matrix, + ) + if self.transistor: + out_ = torch.matmul( + unfolded_batch_input, + devices, + ).T + else: + out_ = memtorch.bh.crossbar.Passive.solve_passive( + devices, + unfolded_batch_input.to(self.device), + torch.zeros( + unfolded_batch_input.shape[0], devices.shape[1] + ), + self.source_resistance, + self.line_resistance, + n_input_batches=unfolded_batch_input.shape[0], + use_bindings=self.use_bindings, + cuda_malloc_heap_size=self.cuda_malloc_heap_size, + ).T + if self.quant_method is not None: out_ = memtorch.bh.Quantize.quantize( out_, diff --git a/memtorch/mn/Conv2d.py b/memtorch/mn/Conv2d.py index a9661555..fe497eb8 100644 --- a/memtorch/mn/Conv2d.py +++ b/memtorch/mn/Conv2d.py @@ -1,4 +1,5 @@ import math +import warnings import numpy as np import torch @@ -43,6 +44,10 @@ class Conv2d(nn.Conv2d): Scaling routine to use in order to scale batch inputs. scaling_routine_params : **kwargs Scaling routine keyword arguments. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. ADC_overflow_rate : float @@ -70,6 +75,8 @@ def __init__( max_input_voltage=None, scaling_routine=naive_scale, scaling_routine_params={}, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, @@ -82,13 +89,29 @@ def __init__( convolutional_layer, nn.Conv2d ), "convolutional_layer is not an instance of nn.Conv2d." self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") + self.transistor = transistor self.scheme = scheme self.tile_shape = tile_shape self.max_input_voltage = max_input_voltage self.scaling_routine = scaling_routine self.scaling_routine_params = scaling_routine_params + self.source_resistance = source_resistance + self.line_resistance = line_resistance self.ADC_resolution = ADC_resolution self.ADC_overflow_rate = ADC_overflow_rate + if "cpu" not in memtorch.__version__: + self.cuda_malloc_heap_size = 50 + else: + self.cuda_malloc_heap_size = None + + if not transistor: + assert ( + source_resistance is not None and source_resistance >= 0.0 + ), "Source resistance is invalid." + assert ( + line_resistance is not None and line_resistance >= 0.0 + ), "Line resistance is invalid." 
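For reference, the passive solver used in the non-tiled `transistor=False` branch above can also be called directly. The following is a minimal sketch (not part of this patch) that mirrors the positional argument order used in the patched `forward` methods; the crossbar size, batch size, and resistance values are purely illustrative, and the result is assumed to contain one row of readout currents per input batch.

    import torch
    import memtorch

    conductance_matrix = torch.ones(16, 8) * 1e-4  # illustrative 16x8 passive crossbar
    inputs = torch.rand(4, 16)  # four batches of word-line voltages
    readout_currents = memtorch.bh.crossbar.Passive.solve_passive(
        conductance_matrix,
        inputs,
        torch.zeros(inputs.shape[0], conductance_matrix.shape[1]),  # grounded bit lines
        5.0,  # source_resistance (Ohm)
        2.5,  # line_resistance (Ohm)
        n_input_batches=inputs.shape[0],
    )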
+ if quant_method in memtorch.bh.Quantize.quant_methods: self.quant_method = quant_method else: @@ -206,8 +229,10 @@ def forward(self, input): -1, self.in_channels * self.kernel_size[0] * self.kernel_size[1] ) ) - unfolded_batch_input_shape = unfolded_batch_input.shape if hasattr(self, "non_linear"): + warnings.warn( + "Non-liner modeling does not currently account for source and line resistances." + ) if self.tile_shape is not None: tiles_map = self.crossbars[0].tiles_map crossbar_shape = ( @@ -245,15 +270,33 @@ def forward(self, input): ) else: if self.tile_shape is not None: - out_ = tiled_inference(unfolded_batch_input, self).T - else: - out_ = torch.matmul( - unfolded_batch_input, - self.crossbar_operation( - self.crossbars, - lambda crossbar: crossbar.conductance_matrix, - ), + out_ = tiled_inference( + unfolded_batch_input, self, transistor=self.transistor ).T + else: + devices = self.crossbar_operation( + self.crossbars, + lambda crossbar: crossbar.conductance_matrix, + ) + if self.transistor: + out_ = torch.matmul( + unfolded_batch_input, + devices, + ).T + else: + out_ = memtorch.bh.crossbar.Passive.solve_passive( + devices, + unfolded_batch_input.to(self.device), + torch.zeros( + unfolded_batch_input.shape[0], devices.shape[1] + ), + self.source_resistance, + self.line_resistance, + n_input_batches=unfolded_batch_input.shape[0], + use_bindings=self.use_bindings, + cuda_malloc_heap_size=self.cuda_malloc_heap_size, + ).T + if self.quant_method is not None: out_ = memtorch.bh.Quantize.quantize( out_, diff --git a/memtorch/mn/Conv3d.py b/memtorch/mn/Conv3d.py index ad265362..1a627070 100644 --- a/memtorch/mn/Conv3d.py +++ b/memtorch/mn/Conv3d.py @@ -1,4 +1,5 @@ import math +import warnings import numpy as np import torch @@ -43,6 +44,10 @@ class Conv3d(nn.Conv3d): Scaling routine to use in order to scale batch inputs. scaling_routine_params : **kwargs Scaling routine keyword arguments. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. ADC_overflow_rate : float @@ -70,6 +75,8 @@ def __init__( max_input_voltage=None, scaling_routine=naive_scale, scaling_routine_params={}, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, @@ -82,13 +89,29 @@ def __init__( convolutional_layer, nn.Conv3d ), "convolutional_layer is not an instance of nn.Conv3d." self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") + self.transistor = transistor self.scheme = scheme self.tile_shape = tile_shape self.max_input_voltage = max_input_voltage self.scaling_routine = scaling_routine self.scaling_routine_params = scaling_routine_params + self.source_resistance = source_resistance + self.line_resistance = line_resistance self.ADC_resolution = ADC_resolution self.ADC_overflow_rate = ADC_overflow_rate + if "cpu" not in memtorch.__version__: + self.cuda_malloc_heap_size = 50 + else: + self.cuda_malloc_heap_size = None + + if not transistor: + assert ( + source_resistance is not None and source_resistance >= 0.0 + ), "Source resistance is invalid." + assert ( + line_resistance is not None and line_resistance >= 0.0 + ), "Line resistance is invalid." 
+ if quant_method in memtorch.bh.Quantize.quant_methods: self.quant_method = quant_method else: @@ -226,8 +249,10 @@ def forward(self, input): * self.kernel_size[2], ) ) - unfolded_batch_input_shape = unfolded_batch_input.shape if hasattr(self, "non_linear"): + warnings.warn( + "Non-liner modeling does not currently account for source and line resistances." + ) if self.tile_shape is not None: tiles_map = self.crossbars[0].tiles_map crossbar_shape = ( @@ -265,15 +290,33 @@ def forward(self, input): ) else: if self.tile_shape is not None: - out_ = tiled_inference(unfolded_batch_input, self).T - else: - out_ = torch.matmul( - unfolded_batch_input, - self.crossbar_operation( - self.crossbars, - lambda crossbar: crossbar.conductance_matrix, - ), + out_ = tiled_inference( + unfolded_batch_input, self, transistor=self.transistor ).T + else: + devices = self.crossbar_operation( + self.crossbars, + lambda crossbar: crossbar.conductance_matrix, + ) + if self.transistor: + out_ = torch.matmul( + unfolded_batch_input, + devices, + ).T + else: + out_ = memtorch.bh.crossbar.Passive.solve_passive( + devices, + unfolded_batch_input.to(self.device), + torch.zeros( + unfolded_batch_input.shape[0], devices.shape[1] + ), + self.source_resistance, + self.line_resistance, + n_input_batches=unfolded_batch_input.shape[0], + use_bindings=self.use_bindings, + cuda_malloc_heap_size=self.cuda_malloc_heap_size, + ).T + if self.quant_method is not None: out_ = memtorch.bh.Quantize.quantize( out_, diff --git a/memtorch/mn/Linear.py b/memtorch/mn/Linear.py index 569d5d63..b18eb8aa 100644 --- a/memtorch/mn/Linear.py +++ b/memtorch/mn/Linear.py @@ -1,4 +1,5 @@ import math +import warnings import numpy as np import torch @@ -43,6 +44,10 @@ class Linear(nn.Linear): Scaling routine to use in order to scale batch inputs. scaling_routine_params : **kwargs Scaling routine keyword arguments. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. ADC_overflow_rate : float @@ -70,6 +75,8 @@ def __init__( max_input_voltage=None, scaling_routine=naive_scale, scaling_routine_params={}, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, @@ -82,13 +89,29 @@ def __init__( linear_layer, nn.Linear ), "linear_layer is not an instance of nn.Linear." self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") + self.transistor = transistor self.scheme = scheme self.tile_shape = tile_shape self.max_input_voltage = max_input_voltage self.scaling_routine = scaling_routine self.scaling_routine_params = scaling_routine_params + self.source_resistance = source_resistance + self.line_resistance = line_resistance self.ADC_resolution = ADC_resolution self.ADC_overflow_rate = ADC_overflow_rate + if "cpu" not in memtorch.__version__: + self.cuda_malloc_heap_size = 50 + else: + self.cuda_malloc_heap_size = None + + if not transistor: + assert ( + source_resistance is not None and source_resistance >= 0.0 + ), "Source resistance is invalid." + assert ( + line_resistance is not None and line_resistance >= 0.0 + ), "Line resistance is invalid." 
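The same computation is exposed at a lower level through the `solve_passive` binding registered in `memtorch/cu/solve_passive.cpp` earlier in this patch. A hedged sketch of calling that binding directly, assuming the compiled CUDA extension is importable as `memtorch_cuda_bindings` (the extension name is not visible in this hunk) and using the keyword names declared via `py::arg`:

    import torch
    import memtorch_cuda_bindings  # assumed name of the compiled CUDA extension

    conductance_matrix = torch.ones(16, 8) * 1e-4
    I_out = memtorch_cuda_bindings.solve_passive(
        conductance_matrix=conductance_matrix,
        V_WL=torch.rand(16),   # word-line voltages
        V_BL=torch.zeros(8),   # grounded bit lines
        R_source=5.0,
        R_line=2.5,
        det_readout_currents=True,  # False returns the per-device applied-voltage tensor instead
    )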
+ if quant_method in memtorch.bh.Quantize.quant_methods: self.quant_method = quant_method else: @@ -160,9 +183,11 @@ def forward(self, input): return out else: - input_shape = input.shape input = self.scaling_routine(self, input, **self.scaling_routine_params) if hasattr(self, "non_linear"): + warnings.warn( + "Non-liner modeling does not currently account for source and line resistances." + ) if self.tile_shape is not None: tiles_map = self.crossbars[0].tiles_map crossbar_shape = self.weight.data.shape @@ -193,14 +218,28 @@ def forward(self, input): ).to(self.device) else: if self.tile_shape is not None: - out_ = tiled_inference(input, self) + out_ = tiled_inference(input, self, transistor=self.transistor) else: - out_ = torch.matmul( - input.to(self.device), - self.crossbar_operation( - self.crossbars, lambda crossbar: crossbar.conductance_matrix - ), + devices = self.crossbar_operation( + self.crossbars, lambda crossbar: crossbar.conductance_matrix ) + if self.transistor: + out_ = torch.matmul( + input.to(self.device), + devices, + ) + else: + out_ = memtorch.bh.crossbar.Passive.solve_passive( + devices, + input.to(self.device), + torch.zeros(input.shape[0], devices.shape[1]), + self.source_resistance, + self.line_resistance, + n_input_batches=input.shape[0], + use_bindings=self.use_bindings, + cuda_malloc_heap_size=self.cuda_malloc_heap_size, + ) + if self.quant_method is not None: out_ = memtorch.bh.Quantize.quantize( out_, diff --git a/memtorch/mn/Module.py b/memtorch/mn/Module.py index 0ebc660d..88efe73a 100644 --- a/memtorch/mn/Module.py +++ b/memtorch/mn/Module.py @@ -36,6 +36,8 @@ def patch_model( max_input_voltage=None, scaling_routine=naive_scale, scaling_routine_params={}, + source_resistance=None, + line_resistance=None, ADC_resolution=None, ADC_overflow_rate=0.0, quant_method=None, @@ -75,6 +77,10 @@ def patch_model( Scaling routine to use in order to scale batch inputs. scaling_routine_params : **kwargs Scaling routine keyword arguments. + source_resistance : float + The resistance between word/bit line voltage sources and crossbar(s). + line_resistance : float + The interconnect line resistance between adjacent cells. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. 
ADC_overflow_rate : float @@ -111,6 +117,8 @@ def patch_model( max_input_voltage=max_input_voltage, scaling_routine=scaling_routine, scaling_routine_params=scaling_routine_params, + source_resistance=source_resistance, + line_resistance=line_resistance, ADC_resolution=ADC_resolution, ADC_overflow_rate=ADC_overflow_rate, quant_method=quant_method, @@ -176,8 +184,15 @@ def disable_legacy(self): self.forward_legacy(False) delattr(self, "forward_legacy") + def set_cuda_malloc_heap_size(self, cuda_malloc_heap_size): + """Method to set the CUDA malloc heap size.""" + for i, (name, m) in enumerate(list(self.named_modules())): + if type(m) in supported_module_parameters.values(): + m.cuda_malloc_heap_size = cuda_malloc_heap_size + model.forward_legacy = forward_legacy.__get__(model) model.tune_ = tune_.__get__(model) model.forward_legacy(False) model.disable_legacy = disable_legacy.__get__(model) + model.set_cuda_malloc_heap_size = set_cuda_malloc_heap_size.__get__(model) return model diff --git a/memtorch/version.py b/memtorch/version.py index b544b44d..88fbfbbf 100644 --- a/memtorch/version.py +++ b/memtorch/version.py @@ -1 +1 @@ -__version__ = "1.1.3-cpu" +__version__ = "1.1.4-cpu" diff --git a/setup.py b/setup.py index 2f03869d..c9427b88 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,9 @@ import torch from setuptools import find_packages, setup -from torch.utils.cpp_extension import include_paths +from torch.utils.cpp_extension import include_paths, library_paths -version = "1.1.3" +version = "1.1.4" CUDA = False @@ -33,13 +33,23 @@ def create_version_py(version, CUDA): library_dirs=["memtorch/submodules"], include_dirs=[ os.path.join(os.getcwd(), relative_path) - for relative_path in ["memtorch/cu/", "memtorch/submodules/eigen/"] + for relative_path in [ + "memtorch/cu/", + "memtorch/submodules/eigen/", + ] ], + extra_compile_args=["-lineinfo"], ), CppExtension( name="memtorch_bindings", sources=glob.glob("memtorch/cpp/*.cpp"), - include_dirs=["memtorch/cpp/"], + include_dirs=[ + os.path.join(os.getcwd(), relative_path) + for relative_path in [ + "memtorch/cpp/", + "memtorch/submodules/eigen/", + ] + ], ), ] name = "memtorch" @@ -50,7 +60,13 @@ def create_version_py(version, CUDA): CppExtension( name="memtorch_bindings", sources=glob.glob("memtorch/cpp/*.cpp"), - include_dirs=["memtorch/cpp/"], + include_dirs=[ + os.path.join(os.getcwd(), relative_path) + for relative_path in [ + "memtorch/cpp/", + "memtorch/submodules/eigen/", + ] + ], ) ] name = "memtorch-cpu" diff --git a/tests/test_cpp_extensions.py b/tests/test_cpp_extensions.py deleted file mode 100644 index 63d91fd9..00000000 --- a/tests/test_cpp_extensions.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -import memtorch - -# if "cpu" in memtorch.__version__: -# import quantization -# else: -# import cuda_quantization as quantization - -import copy -import math -import random - -import matplotlib -import numpy as np - - -# TBD -# @pytest.mark.parametrize( -# "shape, quantization_levels", [((20, 50), 10), ((100, 100), 5)] -# ) -# def test_quantize(shape, quantization_levels): -# device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") -# tensor = torch.zeros(shape).uniform_(0, 1).to(device) -# quantized_tensor = copy.deepcopy(tensor) -# quantization.quantize(quantized_tensor, quantization_levels, 0, 1) -# valid_values = torch.linspace(0, 1, quantization_levels) -# quantized_tensor_unique = quantized_tensor.unique() -# assert any( -# [ -# bool(val) -# for val in [ -# 
torch.isclose(quantized_tensor_unique, valid_value).any() -# for valid_value in valid_values -# ] -# ] -# ) -# assert tensor.shape == quantized_tensor.shape -# assert math.isclose( -# min(valid_values.tolist(), key=lambda x: abs(x - tensor[0][0])), -# quantized_tensor[0][0], -# ) -# assert math.isclose( -# min(valid_values.tolist(), key=lambda x: abs(x - tensor[0][1])), -# quantized_tensor[0][1], -# ) diff --git a/tests/test_networks.py b/tests/test_networks.py index 00d6498e..093974e4 100644 --- a/tests/test_networks.py +++ b/tests/test_networks.py @@ -4,6 +4,7 @@ import numpy as np import pytest import torch +from numpy.lib import source import memtorch from memtorch.bh.crossbar.Program import naive_program @@ -13,8 +14,17 @@ @pytest.mark.parametrize("tile_shape", [None, (128, 128), (10, 20)]) @pytest.mark.parametrize("quant_method", memtorch.bh.Quantize.quant_methods + [None]) +@pytest.mark.parametrize("source_resistance", [None, 5]) +@pytest.mark.parametrize("line_resistance", [None, 5]) @pytest.mark.parametrize("use_bindings", [True, False]) -def test_networks(debug_networks, tile_shape, quant_method, use_bindings): +def test_networks( + debug_networks, + tile_shape, + quant_method, + source_resistance, + line_resistance, + use_bindings, +): networks = debug_networks if quant_method is not None: ADC_resolution = 8 @@ -35,6 +45,8 @@ def test_networks(debug_networks, tile_shape, quant_method, use_bindings): max_input_voltage=1.0, ADC_resolution=ADC_resolution, quant_method=quant_method, + source_resistance=source_resistance, + line_resistance=line_resistance, use_bindings=use_bindings, ) patched_network.tune_()
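Taken together, the Python-facing changes above are exercised as in the following hedged usage sketch (not itself part of the patch). The memristor-model and mapping arguments are elided, `Net`, `device`, and `input_shape` are placeholders, and `transistor` is assumed to remain part of `patch_model`'s existing signature; the heap size value (in MiB, per the 1024 * 1024 scaling in `tile_matmul_kernels.cu`) is illustrative, following the release note's suggestion to set it manually when simulating source and line resistances with the CUDA bindings.

    import torch
    import memtorch
    from memtorch.mn.Module import patch_model

    model = Net().to(device)  # placeholder: any pretrained torch.nn.Module
    patched_model = patch_model(
        model,
        # ... existing memristor model / mapping arguments ...
        transistor=False,  # assumed existing argument; simulate passive crossbars
        tile_shape=(128, 128),
        max_input_voltage=1.0,
        source_resistance=5,
        line_resistance=5,
        use_bindings=True,
    )
    patched_model.set_cuda_malloc_heap_size(100)
    patched_model.tune_()
    output = patched_model(torch.zeros((1, *input_shape)).to(device))  # placeholder input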