From c06314bd0a3cdf60f698d8eda3469cbddba2c0dc Mon Sep 17 00:00:00 2001 From: "Fangrui.Liu" Date: Sun, 10 Mar 2024 16:10:50 +0800 Subject: [PATCH] add TCM --- src/main/scala/ncore/pe/procElem.scala | 4 +- .../scala/ncore/tcm/ tightCoupledMem.scala | 59 +++++++++++ src/test/scala/ncore/CoreSpec.scala | 2 +- src/test/scala/ncore/cu/CUSpec.scala | 2 +- src/test/scala/ncore/pe/PESpec.scala | 2 +- src/test/scala/ncore/tcm/TCMSpec.scala | 100 ++++++++++++++++++ 6 files changed, 164 insertions(+), 5 deletions(-) create mode 100644 src/main/scala/ncore/tcm/ tightCoupledMem.scala create mode 100644 src/test/scala/ncore/tcm/TCMSpec.scala diff --git a/src/main/scala/ncore/pe/procElem.scala b/src/main/scala/ncore/pe/procElem.scala index ff1a662..bf88bcc 100644 --- a/src/main/scala/ncore/pe/procElem.scala +++ b/src/main/scala/ncore/pe/procElem.scala @@ -12,8 +12,8 @@ class PE(val nbits: Int = 8) extends Module { val io = IO( new Bundle { val accum = Input(Bool()) - val in_a = Input(UInt(nbits.W)) - val in_b = Input(UInt(nbits.W)) + val in_a = Input(UInt(nbits.W)) + val in_b = Input(UInt(nbits.W)) // The register bandwith is optimized for large transformer // The lower bound of max cap matrix size is: // 2^12 x 2^12 = (4096 x 4096) diff --git a/src/main/scala/ncore/tcm/ tightCoupledMem.scala b/src/main/scala/ncore/tcm/ tightCoupledMem.scala new file mode 100644 index 0000000..d438438 --- /dev/null +++ b/src/main/scala/ncore/tcm/ tightCoupledMem.scala @@ -0,0 +1,59 @@ +// See README.md for license details. 
+
+package ncore.tcm
+
+import chisel3._
+import chisel3.util._
+import chisel3.util.experimental.decode
+/** Single storage cell: one nbits-wide register with a write enable. */
+class TCMCell(val nbits: Int = 8) extends Module {
+  val io = IO(
+    new Bundle {
+      val d_in = Input(UInt(nbits.W)) // value captured into the register when en_wr is high
+      val d_out = Output(UInt(nbits.W)) // current register contents, always driven
+      val en_wr = Input(Bool()) // write enable
+    }
+  )
+
+  val reg = RegInit(0.U(nbits.W)) // storage element, reset value 0
+  io.d_out := reg
+
+  when (io.en_wr) { // no otherwise-branch: the register is only reassigned while en_wr is high
+    reg := io.d_in
+  }
+}
+/** Memory built from `size` TCMCell instances with n * n read ports and n * n write ports. */
+class TCMBlock(val n: Int = 8, // port count is n * n for both reads and writes
+  val size: Int = 4096, // number of TCMCell instances
+  val r_addr_width: Int = 12, // bit width of each read address
+  val w_addr_width: Int = 12, // bit width of each write address
+  val nbits: Int = 8 // bit width of each stored word
+) extends Module {
+  val io = IO(
+    new Bundle {
+      val d_in = Input(Vec(n * n, UInt(nbits.W))) // d_in(i) goes to the cell selected by w_addr(i)
+      val d_out = Output(Vec(n * n, UInt(nbits.W))) // d_out(i) comes from the cell selected by r_addr(i)
+      val r_addr = Input(Vec(n * n, UInt(r_addr_width.W)))
+      val w_addr = Input(Vec(n * n, UInt(w_addr_width.W)))
+      val en_wr = Input(Bool()) // one write enable shared by all n * n write ports
+    }
+  )
+  val cells_io = VecInit(Seq.fill(size) {Module(new TCMCell(nbits)).io}) // io bundles of all cells, indexable by a hardware address
+
+  for (i <- 0 until size) {
+    cells_io(i).en_wr := false.B.asTypeOf(cells_io(i).en_wr)
+    // Default drivers for every cell input; the per-port connections further down come later and override them for the selected cells, so this loop must stay first.
+    cells_io(i).d_in := 0.U.asTypeOf(cells_io(i).d_in)
+  }
+
+  //TODO: add range check (r_addr / w_addr are not validated against size anywhere in this module)
+  //TODO: add read & write conflict check
+
+  for (i <- 0 until n * n) {
+    io.d_out(i) := cells_io(io.r_addr(i)).d_out // read port i: connected directly to the selected cell's output
+    when (io.en_wr) {
+      cells_io(io.w_addr(i)).en_wr := io.en_wr // write port i: enable the selected cell ...
+      cells_io(io.w_addr(i)).d_in := io.d_in(i) // ... and present its new value
+    }
+  }
+}
diff --git a/src/test/scala/ncore/CoreSpec.scala b/src/test/scala/ncore/CoreSpec.scala index 0f80604..868e733 100644 --- a/src/test/scala/ncore/CoreSpec.scala +++ b/src/test/scala/ncore/CoreSpec.scala @@ -14,7 +14,7 @@ class CoreSpec extends AnyFlatSpec with ChiselScalatestTester { "NeuralCore" should "do a normal matrix multiplication" in { test(new NeuralCore(4, 8)) { dut => val print_helper = new testUtil.PrintHelper() - val _n = 4 + val _n = dut.n val rand = new Random val _mat_a = new Array[Int](_n * _n) val _mat_b = new Array[Int](_n * _n) diff --git a/src/test/scala/ncore/cu/CUSpec.scala b/src/test/scala/ncore/cu/CUSpec.scala index 03c02a7..b0863d4 100644 --- a/src/test/scala/ncore/cu/CUSpec.scala +++ b/src/test/scala/ncore/cu/CUSpec.scala @@ -14,7 +14,7 @@ class CUSpec extends AnyFlatSpec with ChiselScalatestTester { "CU" should "send control to 2D systolic array" in { test(new ControlUnit(4)) { dut => val print_helper = new testUtil.PrintHelper() - val _n = 4 + val _n = dut.n val rand = new Random var history = new Array[Int](2 * _n - 1) var prod = 0 diff --git a/src/test/scala/ncore/pe/PESpec.scala b/src/test/scala/ncore/pe/PESpec.scala index bd9c3ae..b08d114 100644 --- a/src/test/scala/ncore/pe/PESpec.scala +++ b/src/test/scala/ncore/pe/PESpec.scala @@ -12,7 +12,7 @@ import chisel3.experimental.BundleLiterals._ class PESpec extends AnyFlatSpec with ChiselScalatestTester { "PE" should "output multiplied number from top and left" in { - test(new PE(16)) { dut => + test(new PE(8)) { dut => val rand = new Random var prod = 0 for (n <- 0 until 128) { diff --git a/src/test/scala/ncore/tcm/TCMSpec.scala b/src/test/scala/ncore/tcm/TCMSpec.scala 
new file mode 100644
index 0000000..5e8f311
--- /dev/null
+++ b/src/test/scala/ncore/tcm/TCMSpec.scala
@@ -0,0 +1,100 @@
+// See README.md for license details.
+
+package ncore.tcm
+
+import scala.util.Random
+import chisel3._
+import testUtil._
+import chiseltest._
+import org.scalatest.flatspec.AnyFlatSpec
+import chisel3.experimental.BundleLiterals._
+
+/** Simulation tests for TCMCell (single register cell) and TCMBlock (multi-port memory). */
+class TCMSpec extends AnyFlatSpec with ChiselScalatestTester {
+
+  "TCM Cells" should "write on signal" in {
+    test(new TCMCell(8)) { dut =>
+      val rand = new Random
+      var _prev = 0 // value the cell is expected to hold; TCMCell's register resets to 0
+      for (i <- 0 until 10) {
+        val _in = rand.between(0, 256) // upper bound is exclusive, so this covers the full 8-bit range 0..255
+        dut.io.d_out.expect(_prev) // before the clock edge the cell still shows the previous value
+        dut.io.d_in.poke(_in)
+        dut.io.en_wr.poke(true)
+        dut.clock.step()
+        dut.io.d_out.expect(_in) // check the stored value on the output, not the input port that was just poked
+        _prev = _in
+        println("Result tick @ " + i + ": " + dut.io.d_out.peekInt())
+      }
+    }
+  }
+
+  "TCM Block" should "write on signal and read anytime" in {
+    test(new TCMBlock(3, 192)) { dut =>
+      val _n = dut.n
+      val _cells = dut.size
+      val rand = new Random
+      val print_helper = new testUtil.PrintHelper()
+      val _in_data = new Array[Int](_n * _n)
+      for(_i <- 0 until 10){
+        val _in_addr = rand.shuffle((0 until _cells).toList).take(_n * _n) // distinct addresses, one per port
+        for (i <- 0 until _n * _n) {
+          _in_data(i) = rand.between(0, 256) // exclusive upper bound: values 0..255
+          dut.io.d_in(i).poke(_in_data(i))
+          dut.io.w_addr(i).poke(_in_addr(i))
+        }
+        dut.io.en_wr.poke(true)
+        dut.clock.step() // the write takes effect on this clock edge
+        for (i <- 0 until _n * _n) {
+          dut.io.r_addr(i).poke(_in_addr(i)) // read back the addresses that were just written
+        }
+        for (i <- 0 until _n * _n){
+          dut.io.d_out(i).expect(_in_data(i))
+        }
+        println("Result tick @ " + _i + ": ")
+        print_helper.printMatrix(_in_data, _n)
+        // print_helper.printMatrix(_in_addr, _n)
+        print_helper.printMatrixChisel(dut.io.d_out, _n)
+      }
+    }
+  }
+
+  "TCM Block" should "read anytime" in {
+    test(new TCMBlock(2, 64)) { dut =>
+      val _n = dut.n
+      val _cells = dut.size
+      val rand = new Random
+      val print_helper = new testUtil.PrintHelper()
+      val _data = new Array[Int](_cells) // software model of the memory; unwritten entries stay 0 like the reset cells
+      for (_i <- 0 until 10) { // write phase: ten rounds of n * n random writes
+        val _in_data = new Array[Int](_n * _n)
+        val _in_addr = rand.shuffle((0 until _cells).toList).take(_n * _n) // distinct addresses within one round
+        for (i <- 0 until _n * _n) {
+          _in_data(i) = rand.between(0, 256) // exclusive upper bound: values 0..255
+          dut.io.d_in(i).poke(_in_data(i))
+          dut.io.w_addr(i).poke(_in_addr(i))
+          _data(_in_addr(i)) = _in_data(i) // later rounds overwrite earlier ones, as in the hardware
+        }
+        dut.io.en_wr.poke(true)
+        dut.clock.step()
+      }
+      for(_i <- 0 until 10){ // read phase: en_wr is still high, but the clock is never stepped here, so nothing is written
+        val _r_addr = rand.shuffle((0 until _cells).toList).take(_n * _n)
+        val _expected = new Array[Int](_n * _n)
+        for (i <- 0 until _n * _n) {
+          dut.io.r_addr(i).poke(_r_addr(i))
+        }
+        for (i <- 0 until _n * _n) {
+          _expected(i) = _data(_r_addr(i))
+        }
+        println("Result tick @ " + _i + ": ")
+        print_helper.printMatrix(_expected, _n)
+        print_helper.printMatrixChisel(dut.io.d_out, _n)
+        for (i <- 0 until _n * _n){
+          dut.io.d_out(i).expect(_data(_r_addr(i)))
+        }
+      }
+    }
+  }
+
+}
\ No newline at end of file