Skip to content

Commit a6776ca

Browse files
authored
Merge pull request tancheng#259 from guosran/dcu-design
LoopCounter Integration
2 parents 8316500 + bef74f6 commit a6776ca

10 files changed

Lines changed: 1066 additions & 11 deletions

cgra/test/CgraRTL_fir_2x2_loop_counter_test.py

Lines changed: 702 additions & 0 deletions
Large diffs are not rendered by default.

controller/ControllerRTL.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,11 @@ def update_received_msg():
302302
0, # vc_id
303303
s.recv_from_inter_cgra_noc.msg.payload)
304304

305+
# Consume and discard the leaf counter complete signal (loop termination
306+
# notification from LoopCounter FU) to avoid blocking the NoC.
307+
elif s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_LEAF_COUNTER_COMPLETE:
308+
s.recv_from_inter_cgra_noc.rdy @= 1
309+
305310
elif s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_GLOBAL_REDUCE_ADD:
306311
s.recv_from_inter_cgra_noc.rdy @= s.global_reduce_unit.recv_data.rdy
307312
s.global_reduce_unit.recv_data.val @= 1
@@ -327,7 +332,10 @@ def update_received_msg():
327332
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_RESUME) | \
328333
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_RECORD_PHI_ADDR) | \
329334
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_TERMINATE) | \
330-
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_LAUNCH):
335+
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_LAUNCH) | \
336+
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_LOOP_LOWER) | \
337+
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_LOOP_UPPER) | \
338+
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_LOOP_STEP) :
331339
s.recv_from_inter_cgra_noc.rdy @= s.send_to_ctrl_ring_pkt.rdy
332340
s.send_to_ctrl_ring_pkt.val @= s.recv_from_inter_cgra_noc.val
333341
s.send_to_ctrl_ring_pkt.msg @= \

fu/single/ExtractPredicateRTL.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
==========================================================================
3+
ExtractPredicateRTL.py
4+
==========================================================================
5+
Functional Unit that extracts the predicate bit from input data and outputs
6+
it as a boolean result with payload = predicate_value, predicate = 1.
7+
8+
This is used to extract loop termination signals from counter outputs.
9+
10+
Author : Shangkun LI
11+
Date : January 27, 2026
12+
13+
"""
14+
15+
from pymtl3 import *
16+
from ..basic.Fu import Fu
17+
from ...lib.opt_type import *
18+
19+
class ExtractPredicateRTL(Fu):
20+
21+
def construct(s, DataType, CtrlType, num_inports,
22+
num_outports, data_mem_size, ctrl_mem_size = 4,
23+
vector_factor_power = 0,
24+
data_bitwidth = 32):
25+
26+
super(ExtractPredicateRTL, s).construct(DataType, CtrlType,
27+
num_inports, num_outports,
28+
data_mem_size, ctrl_mem_size,
29+
1, vector_factor_power,
30+
data_bitwidth = data_bitwidth)
31+
32+
num_entries = 2
33+
FuInType = mk_bits(clog2(num_inports + 1))
34+
CountType = mk_bits(clog2(num_entries + 1))
35+
36+
s.in0 = Wire(FuInType)
37+
38+
idx_nbits = clog2(num_inports)
39+
s.in0_idx = Wire(idx_nbits)
40+
s.in0_idx //= s.in0[0:idx_nbits]
41+
42+
@update
43+
def comb_logic():
44+
45+
# Default values
46+
s.in0 @= 0
47+
for i in range(num_inports):
48+
s.recv_in[i].rdy @= b1(0)
49+
for i in range(num_outports):
50+
s.send_out[i].val @= b1(0)
51+
s.send_out[i].msg @= DataType()
52+
53+
s.recv_const.rdy @= 0
54+
s.recv_opt.rdy @= 0
55+
56+
s.send_to_ctrl_mem.val @= 0
57+
s.send_to_ctrl_mem.msg @= s.CgraPayloadType(0, 0, 0, 0, 0)
58+
s.recv_from_ctrl_mem.rdy @= 0
59+
60+
if s.recv_opt.val:
61+
if s.recv_opt.msg.fu_in[0] != FuInType(0):
62+
s.in0 @= s.recv_opt.msg.fu_in[0] - FuInType(1)
63+
64+
if s.recv_opt.val:
65+
if s.recv_opt.msg.operation == OPT_EXTRACT_PREDICATE:
66+
# Extracts predicate bit from input and output as payload.
67+
# When loop is running (predicate=1) -> payload=1
68+
# When loop terminates (predicate=0) -> payload=0
69+
# Downstream NOT will invert: running->0 (no RET), done->1 (trigger RET)
70+
s.send_out[0].msg.payload @= zext(s.recv_in[s.in0_idx].msg.predicate, DataType.get_field_type('payload'))
71+
s.send_out[0].msg.predicate @= 1
72+
73+
s.send_out[0].val @= s.recv_in[s.in0_idx].val
74+
s.recv_in[s.in0_idx].rdy @= s.recv_in[s.in0_idx].val & s.send_out[0].rdy
75+
s.recv_opt.rdy @= s.recv_in[s.in0_idx].val & s.send_out[0].rdy
76+
77+
else:
78+
for j in range(num_outports):
79+
s.send_out[j].val @= b1(0)
80+
s.recv_opt.rdy @= 0
81+
s.recv_in[s.in0_idx].rdy @= 0
82+
83+
def line_trace(s):
84+
opt_str = " #"
85+
if s.recv_opt.val:
86+
opt_str = OPT_SYMBOL_DICT[s.recv_opt.msg.operation]
87+
out_str = ",".join([str(x.msg) for x in s.send_out])
88+
recv_str = ",".join([str(x.msg) for x in s.recv_in])
89+
return f'[ExtPred|recv: {recv_str}] {opt_str} = [out: {out_str}]'

fu/single/LoopCounterRTL.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def comb_logic():
117117
# Loop terminated: predicate = 0.
118118
s.send_out[0].msg.predicate @= 0
119119

120-
# Sends CMD_COMPLETE if not already done.
120+
# Sends CMD_LEAF_COUNTER_COMPLETE if not already done.
121121
if ~s.already_done[addr]:
122122
s.send_to_ctrl_mem.val @= b1(1)
123123
s.send_to_ctrl_mem.msg @= s.CgraPayloadType(
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
==========================================================================
3+
ExtractPredicateRTL_test.py
4+
==========================================================================
5+
Test cases for ExtractPredicateRTL functional unit.
6+
7+
Author : Shangkun LI
8+
Date : January 27, 2026
9+
"""
10+
11+
import pytest
12+
13+
from pymtl3 import *
14+
from pymtl3.stdlib.test_utils import (run_sim, config_model_with_cmdline_opts)
15+
16+
from ....lib.messages import *
17+
from ....lib.opt_type import *
18+
from ..ExtractPredicateRTL import ExtractPredicateRTL
19+
from ....lib.basic.val_rdy.SourceRTL import SourceRTL as TestSrcRTL
20+
from ....lib.basic.val_rdy.SinkRTL import SinkRTL as TestSinkRTL
21+
22+
#-------------------------------------------------------------------------
23+
# Test harness
24+
#-------------------------------------------------------------------------
25+
26+
class TestHarness(Component):
27+
28+
def construct(s, FunctionUnit, DataType, CtrlType,
29+
num_inports, num_outports,
30+
data_mem_size, src_in0, src_opt, sink_out):
31+
32+
s.src_in0 = TestSrcRTL(DataType, src_in0)
33+
s.src_opt = TestSrcRTL(CtrlType, src_opt)
34+
s.sink_out = TestSinkRTL(DataType, sink_out)
35+
36+
s.dut = FunctionUnit(DataType, CtrlType,
37+
num_inports, num_outports,
38+
data_mem_size)
39+
40+
FuInType = mk_bits(clog2(num_inports + 1))
41+
42+
# Connections
43+
s.src_in0.send //= s.dut.recv_in[0]
44+
s.src_opt.send //= s.dut.recv_opt
45+
s.dut.send_out[0] //= s.sink_out.recv
46+
47+
# Tie off unused ports
48+
s.dut.recv_const.val //= 0
49+
s.dut.recv_const.msg //= DataType()
50+
for i in range(1, num_inports):
51+
s.dut.recv_in[i].val //= 0
52+
s.dut.recv_in[i].msg //= DataType()
53+
for i in range(1, num_outports):
54+
s.dut.send_out[i].rdy //= 0
55+
56+
s.dut.recv_from_ctrl_mem.val //= 0
57+
s.dut.recv_from_ctrl_mem.msg //= s.dut.CgraPayloadType()
58+
s.dut.send_to_ctrl_mem.rdy //= 0
59+
60+
def done(s):
61+
return s.src_in0.done() and s.src_opt.done() and s.sink_out.done()
62+
63+
def line_trace(s):
64+
return s.dut.line_trace()
65+
66+
def run_sim(th, max_cycles=100):
67+
th.elaborate()
68+
th.apply(DefaultPassGroup())
69+
th.sim_reset()
70+
71+
ncycles = 0
72+
print()
73+
print("{:3}: {}".format(ncycles, th.line_trace()))
74+
while not th.done() and ncycles < max_cycles:
75+
th.sim_tick()
76+
ncycles += 1
77+
print("{:3}: {}".format(ncycles, th.line_trace()))
78+
79+
assert ncycles < max_cycles
80+
th.sim_tick()
81+
th.sim_tick()
82+
th.sim_tick()
83+
84+
#-------------------------------------------------------------------------
85+
# Test cases
86+
#-------------------------------------------------------------------------
87+
88+
def test_extract_predicate_basic():
89+
"""Test basic predicate extraction"""
90+
91+
num_inports = 4
92+
num_outports = 2
93+
data_mem_size = 8
94+
95+
data_bitwidth = 32
96+
DataType = mk_data(data_bitwidth, 1)
97+
num_ctrl_operations = 64
98+
num_fu_inports = num_inports
99+
num_fu_outports = num_outports
100+
num_tile_inports = 8
101+
num_tile_outports = 8
102+
num_registers_per_reg_bank = 16
103+
CtrlType = mk_ctrl(num_fu_inports, num_fu_outports,
104+
num_tile_inports, num_tile_outports,
105+
num_registers_per_reg_bank)
106+
FuInType = mk_bits(clog2(num_inports + 1))
107+
108+
# Input data with different predicates
109+
# payload doesn't matter, only predicate is extracted
110+
src_in0 = [
111+
DataType(100, 1), # predicate = 1
112+
DataType(200, 0), # predicate = 0
113+
DataType(300, 1), # predicate = 1
114+
DataType(400, 0), # predicate = 0
115+
]
116+
117+
# Operations: all OPT_EXTRACT_PREDICATE
118+
src_opt = [
119+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
120+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
121+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
122+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
123+
]
124+
125+
# Expected outputs: payload = extracted predicate, predicate = 1 (always valid)
126+
sink_out = [
127+
DataType(1, 1), # extracted pred=1, output pred=1
128+
DataType(0, 1), # extracted pred=0, output pred=1
129+
DataType(1, 1), # extracted pred=1, output pred=1
130+
DataType(0, 1), # extracted pred=0, output pred=1
131+
]
132+
133+
th = TestHarness(ExtractPredicateRTL, DataType, CtrlType,
134+
num_inports, num_outports, data_mem_size,
135+
src_in0, src_opt, sink_out)
136+
run_sim(th)
137+
138+
def test_extract_predicate_for_loop_termination():
139+
"""Test predicate extraction for loop termination detection"""
140+
141+
num_inports = 4
142+
num_outports = 2
143+
data_mem_size = 8
144+
145+
data_bitwidth = 32
146+
DataType = mk_data(data_bitwidth, 1)
147+
num_ctrl_operations = 64
148+
num_fu_inports = num_inports
149+
num_fu_outports = num_outports
150+
num_tile_inports = 8
151+
num_tile_outports = 8
152+
num_registers_per_reg_bank = 16
153+
CtrlType = mk_ctrl(num_fu_inports, num_fu_outports,
154+
num_tile_inports, num_tile_outports,
155+
num_registers_per_reg_bank)
156+
FuInType = mk_bits(clog2(num_inports + 1))
157+
158+
# Simulating counter output pattern:
159+
# - pred=1 for valid iterations
160+
# - pred=0 when loop terminates
161+
src_in0 = [
162+
DataType(0, 1), # counter=0, pred=1 (valid)
163+
DataType(1, 1), # counter=1, pred=1 (valid)
164+
DataType(2, 1), # counter=2, pred=1 (valid)
165+
DataType(3, 0), # counter=3, pred=0 (terminated!)
166+
]
167+
168+
src_opt = [
169+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
170+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
171+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
172+
CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]),
173+
]
174+
175+
# Expected: extract predicate as boolean for use with NOT and grant_predicate
176+
sink_out = [
177+
DataType(1, 1), # pred=1 -> payload=1 (continue)
178+
DataType(1, 1), # pred=1 -> payload=1 (continue)
179+
DataType(1, 1), # pred=1 -> payload=1 (continue)
180+
DataType(0, 1), # pred=0 -> payload=0 (terminate!)
181+
]
182+
183+
th = TestHarness(ExtractPredicateRTL, DataType, CtrlType,
184+
num_inports, num_outports, data_mem_size,
185+
src_in0, src_opt, sink_out)
186+
run_sim(th)

fu/single/test/LoopCounterRTL_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_leaf_counter_basic():
137137
CgraPayloadType(CMD_LEAF_COUNTER_COMPLETE, DataType(0,0), 0, CtrlType(OPT_LOOP_COUNT), 0)
138138
]
139139

140-
ctrl_addrs = [0]*20
140+
ctrl_addrs = [0]*10
141141

142142
th = TestHarness(LoopCounterRTL, DataType, CtrlType, CgraPayloadType,
143143
num_inports, num_outports,

lib/opt_type.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
OPT_LOOP_CONTROL = OpCodeType( 83 )
109109
OPT_LOOP_COUNT = OpCodeType( 85 )
110110
OPT_LOOP_DELIVERY = OpCodeType( 86 )
111+
OPT_EXTRACT_PREDICATE = OpCodeType( 87 )
111112

112113
OPT_SYMBOL_DICT = {
113114
OPT_START : "(start)",
@@ -197,5 +198,6 @@
197198

198199
OPT_LOOP_CONTROL : "(loop_ctrl)",
199200
OPT_LOOP_COUNT : "(loop_cnt)",
200-
OPT_LOOP_DELIVERY : "(loop_deli)"
201+
OPT_LOOP_DELIVERY : "(loop_deli)",
202+
OPT_EXTRACT_PREDICATE : "(extract_pred)"
201203
}

0 commit comments

Comments
 (0)