Skip to content

Commit

Permalink
add TCM
Browse files Browse the repository at this point in the history
  • Loading branch information
mpskex committed Mar 10, 2024
1 parent 9774678 commit c06314b
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/ncore/pe/procElem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class PE(val nbits: Int = 8) extends Module {
val io = IO(
new Bundle {
val accum = Input(Bool())
val in_a = Input(UInt(nbits.W))
val in_b = Input(UInt(nbits.W))
val in_a = Input(UInt(nbits.W))
val in_b = Input(UInt(nbits.W))
// The register bandwith is optimized for large transformer
// The lower bound of max cap matrix size is:
// 2^12 x 2^12 = (4096 x 4096)
Expand Down
59 changes: 59 additions & 0 deletions src/main/scala/ncore/tcm/ tightCoupledMem.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// See README.md for license details.

package ncore.tcm

import chisel3._
import chisel3.util._
import chisel3.util.experimental.decode

class TCMCell(val nbits: Int = 8) extends Module {
val io = IO(
new Bundle {
val d_in = Input(UInt(nbits.W))
val d_out = Output(UInt(nbits.W))
val en_wr = Input(Bool())
}
)

val reg = RegInit(0.U(nbits.W))
io.d_out := reg

when (io.en_wr) {
reg := io.d_in
}
}

class TCMBlock(val n: Int = 8,
val size: Int = 4096,
val r_addr_width: Int = 12,
val w_addr_width: Int = 12,
val nbits: Int = 8
) extends Module {
val io = IO(
new Bundle {
val d_in = Input(Vec(n * n, UInt(nbits.W)))
val d_out = Output(Vec(n * n, UInt(nbits.W)))
val r_addr = Input(Vec(n * n, UInt(r_addr_width.W)))
val w_addr = Input(Vec(n * n, UInt(w_addr_width.W)))
val en_wr = Input(Bool())
}
)
val cells_io = VecInit(Seq.fill(size) {Module(new TCMCell(nbits)).io})

for (i <- 0 until size) {
cells_io(i).en_wr := false.B.asTypeOf(cells_io(i).en_wr)
// Need to initialize all wires just in case of not selected.
cells_io(i).d_in := 0.U.asTypeOf(cells_io(i).d_in)
}

//TODO: add range check
//TODO: add read & write conflict check

for (i <- 0 until n * n) {
io.d_out(i) := cells_io(io.r_addr(i)).d_out
when (io.en_wr) {
cells_io(io.w_addr(i)).en_wr := io.en_wr
cells_io(io.w_addr(i)).d_in := io.d_in(i)
}
}
}
2 changes: 1 addition & 1 deletion src/test/scala/ncore/CoreSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CoreSpec extends AnyFlatSpec with ChiselScalatestTester {
"NeuralCore" should "do a normal matrix multiplication" in {
test(new NeuralCore(4, 8)) { dut =>
val print_helper = new testUtil.PrintHelper()
val _n = 4
val _n = dut.n
val rand = new Random
val _mat_a = new Array[Int](_n * _n)
val _mat_b = new Array[Int](_n * _n)
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/ncore/cu/CUSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CUSpec extends AnyFlatSpec with ChiselScalatestTester {
"CU" should "send control to 2D systolic array" in {
test(new ControlUnit(4)) { dut =>
val print_helper = new testUtil.PrintHelper()
val _n = 4
val _n = dut.n
val rand = new Random
var history = new Array[Int](2 * _n - 1)
var prod = 0
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/ncore/pe/PESpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import chisel3.experimental.BundleLiterals._
class PESpec extends AnyFlatSpec with ChiselScalatestTester {

"PE" should "output multiplied number from top and left" in {
test(new PE(16)) { dut =>
test(new PE(8)) { dut =>
val rand = new Random
var prod = 0
for (n <- 0 until 128) {
Expand Down
100 changes: 100 additions & 0 deletions src/test/scala/ncore/tcm/TCMSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// See README.md for license details.

package ncore.tcm

import scala.util.Random
import chisel3._
import testUtil._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
import chisel3.experimental.BundleLiterals._


class TCMSpec extends AnyFlatSpec with ChiselScalatestTester {

"TCM Cells" should "write on signal" in {
test(new TCMCell(8)) { dut =>
val rand = new Random
var _prev = 0
for (i <- 0 until 10) {
val _in = rand.between(0, 255)
dut.io.d_out.expect(_prev)
dut.io.d_in.poke(_in)
dut.io.en_wr.poke(true)
dut.clock.step()
dut.io.d_in.expect(_in)
_prev = _in
println("Result tick @ " + i + ": " + dut.io.d_in.peekInt())
}
}
}

"TCM Block" should "write on signal and read anytime" in {
test(new TCMBlock(3, 192)) { dut =>
val _n = dut.n
val _cells = dut.size
val rand = new Random
val print_helper = new testUtil.PrintHelper()
val _in_data = new Array[Int](_n * _n)
for(_i <- 0 until 10){
val _in_addr = rand.shuffle((0 until _cells).toList).take(_n * _n)
for (i <- 0 until _n * _n) {
_in_data(i) = rand.between(0, 255)
dut.io.d_in(i).poke(_in_data(i))
dut.io.w_addr(i).poke(_in_addr(i))
}
dut.io.en_wr.poke(true)
dut.clock.step()
for (i <- 0 until _n * _n) {
dut.io.r_addr(i).poke(_in_addr(i))
}
for (i <- 0 until _n * _n){
dut.io.d_out(i).expect(_in_data(i))
}
println("Result tick @ " + _i + ": ")
print_helper.printMatrix(_in_data, _n)
// print_helper.printMatrix(_in_addr, _n)
print_helper.printMatrixChisel(dut.io.d_out, _n)
}
}
}

"TCM Block" should "read anytime" in {
test(new TCMBlock(2, 64)) { dut =>
val _n = dut.n
val _cells = dut.size
val rand = new Random
val print_helper = new testUtil.PrintHelper()
val _data = new Array[Int](_cells)
for (_i <- 0 until 10) {
val _in_data = new Array[Int](_n * _n)
val _in_addr = rand.shuffle((0 until _cells).toList).take(_n * _n)
for (i <- 0 until _n * _n) {
_in_data(i) = rand.between(0, 255)
dut.io.d_in(i).poke(_in_data(i))
dut.io.w_addr(i).poke(_in_addr(i))
_data(_in_addr(i)) = _in_data(i)
}
dut.io.en_wr.poke(true)
dut.clock.step()
}
for(_i <- 0 until 10){
val _r_addr = rand.shuffle((0 until _cells).toList).take(_n * _n)
val _expected = new Array[Int](_n * _n)
for (i <- 0 until _n * _n) {
dut.io.r_addr(i).poke(_r_addr(i))
}
for (i <- 0 until _n * _n) {
_expected(i) = _data(_r_addr(i))
}
println("Result tick @ " + _i + ": ")
print_helper.printMatrix(_expected, _n)
print_helper.printMatrixChisel(dut.io.d_out, _n)
for (i <- 0 until _n * _n){
dut.io.d_out(i).expect(_data(_r_addr(i)))
}
}
}
}

}

0 comments on commit c06314b

Please sign in to comment.