|
| 1 | +""" |
| 2 | +========================================================================== |
| 3 | +ExtractPredicateRTL_test.py |
| 4 | +========================================================================== |
| 5 | +Test cases for ExtractPredicateRTL functional unit. |
| 6 | +
|
| 7 | +Author : Shangkun LI |
| 8 | + Date : January 27, 2026 |
| 9 | +""" |
| 10 | + |
| 11 | +import pytest |
| 12 | + |
| 13 | +from pymtl3 import * |
| 14 | +from pymtl3.stdlib.test_utils import (run_sim, config_model_with_cmdline_opts) |
| 15 | + |
| 16 | +from ....lib.messages import * |
| 17 | +from ....lib.opt_type import * |
| 18 | +from ..ExtractPredicateRTL import ExtractPredicateRTL |
| 19 | +from ....lib.basic.val_rdy.SourceRTL import SourceRTL as TestSrcRTL |
| 20 | +from ....lib.basic.val_rdy.SinkRTL import SinkRTL as TestSinkRTL |
| 21 | + |
| 22 | +#------------------------------------------------------------------------- |
| 23 | +# Test harness |
| 24 | +#------------------------------------------------------------------------- |
| 25 | + |
| 26 | +class TestHarness(Component): |
| 27 | + |
| 28 | + def construct(s, FunctionUnit, DataType, CtrlType, |
| 29 | + num_inports, num_outports, |
| 30 | + data_mem_size, src_in0, src_opt, sink_out): |
| 31 | + |
| 32 | + s.src_in0 = TestSrcRTL(DataType, src_in0) |
| 33 | + s.src_opt = TestSrcRTL(CtrlType, src_opt) |
| 34 | + s.sink_out = TestSinkRTL(DataType, sink_out) |
| 35 | + |
| 36 | + s.dut = FunctionUnit(DataType, CtrlType, |
| 37 | + num_inports, num_outports, |
| 38 | + data_mem_size) |
| 39 | + |
| 40 | + FuInType = mk_bits(clog2(num_inports + 1)) |
| 41 | + |
| 42 | + # Connections |
| 43 | + s.src_in0.send //= s.dut.recv_in[0] |
| 44 | + s.src_opt.send //= s.dut.recv_opt |
| 45 | + s.dut.send_out[0] //= s.sink_out.recv |
| 46 | + |
| 47 | + # Tie off unused ports |
| 48 | + s.dut.recv_const.val //= 0 |
| 49 | + s.dut.recv_const.msg //= DataType() |
| 50 | + for i in range(1, num_inports): |
| 51 | + s.dut.recv_in[i].val //= 0 |
| 52 | + s.dut.recv_in[i].msg //= DataType() |
| 53 | + for i in range(1, num_outports): |
| 54 | + s.dut.send_out[i].rdy //= 0 |
| 55 | + |
| 56 | + s.dut.recv_from_ctrl_mem.val //= 0 |
| 57 | + s.dut.recv_from_ctrl_mem.msg //= s.dut.CgraPayloadType() |
| 58 | + s.dut.send_to_ctrl_mem.rdy //= 0 |
| 59 | + |
| 60 | + def done(s): |
| 61 | + return s.src_in0.done() and s.src_opt.done() and s.sink_out.done() |
| 62 | + |
| 63 | + def line_trace(s): |
| 64 | + return s.dut.line_trace() |
| 65 | + |
| 66 | +def run_sim(th, max_cycles=100): |
| 67 | + th.elaborate() |
| 68 | + th.apply(DefaultPassGroup()) |
| 69 | + th.sim_reset() |
| 70 | + |
| 71 | + ncycles = 0 |
| 72 | + print() |
| 73 | + print("{:3}: {}".format(ncycles, th.line_trace())) |
| 74 | + while not th.done() and ncycles < max_cycles: |
| 75 | + th.sim_tick() |
| 76 | + ncycles += 1 |
| 77 | + print("{:3}: {}".format(ncycles, th.line_trace())) |
| 78 | + |
| 79 | + assert ncycles < max_cycles |
| 80 | + th.sim_tick() |
| 81 | + th.sim_tick() |
| 82 | + th.sim_tick() |
| 83 | + |
| 84 | +#------------------------------------------------------------------------- |
| 85 | +# Test cases |
| 86 | +#------------------------------------------------------------------------- |
| 87 | + |
| 88 | +def test_extract_predicate_basic(): |
| 89 | + """Test basic predicate extraction""" |
| 90 | + |
| 91 | + num_inports = 4 |
| 92 | + num_outports = 2 |
| 93 | + data_mem_size = 8 |
| 94 | + |
| 95 | + data_bitwidth = 32 |
| 96 | + DataType = mk_data(data_bitwidth, 1) |
| 97 | + num_ctrl_operations = 64 |
| 98 | + num_fu_inports = num_inports |
| 99 | + num_fu_outports = num_outports |
| 100 | + num_tile_inports = 8 |
| 101 | + num_tile_outports = 8 |
| 102 | + num_registers_per_reg_bank = 16 |
| 103 | + CtrlType = mk_ctrl(num_fu_inports, num_fu_outports, |
| 104 | + num_tile_inports, num_tile_outports, |
| 105 | + num_registers_per_reg_bank) |
| 106 | + FuInType = mk_bits(clog2(num_inports + 1)) |
| 107 | + |
| 108 | + # Input data with different predicates |
| 109 | + # payload doesn't matter, only predicate is extracted |
| 110 | + src_in0 = [ |
| 111 | + DataType(100, 1), # predicate = 1 |
| 112 | + DataType(200, 0), # predicate = 0 |
| 113 | + DataType(300, 1), # predicate = 1 |
| 114 | + DataType(400, 0), # predicate = 0 |
| 115 | + ] |
| 116 | + |
| 117 | + # Operations: all OPT_EXTRACT_PREDICATE |
| 118 | + src_opt = [ |
| 119 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 120 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 121 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 122 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 123 | + ] |
| 124 | + |
| 125 | + # Expected outputs: payload = extracted predicate, predicate = 1 (always valid) |
| 126 | + sink_out = [ |
| 127 | + DataType(1, 1), # extracted pred=1, output pred=1 |
| 128 | + DataType(0, 1), # extracted pred=0, output pred=1 |
| 129 | + DataType(1, 1), # extracted pred=1, output pred=1 |
| 130 | + DataType(0, 1), # extracted pred=0, output pred=1 |
| 131 | + ] |
| 132 | + |
| 133 | + th = TestHarness(ExtractPredicateRTL, DataType, CtrlType, |
| 134 | + num_inports, num_outports, data_mem_size, |
| 135 | + src_in0, src_opt, sink_out) |
| 136 | + run_sim(th) |
| 137 | + |
| 138 | +def test_extract_predicate_for_loop_termination(): |
| 139 | + """Test predicate extraction for loop termination detection""" |
| 140 | + |
| 141 | + num_inports = 4 |
| 142 | + num_outports = 2 |
| 143 | + data_mem_size = 8 |
| 144 | + |
| 145 | + data_bitwidth = 32 |
| 146 | + DataType = mk_data(data_bitwidth, 1) |
| 147 | + num_ctrl_operations = 64 |
| 148 | + num_fu_inports = num_inports |
| 149 | + num_fu_outports = num_outports |
| 150 | + num_tile_inports = 8 |
| 151 | + num_tile_outports = 8 |
| 152 | + num_registers_per_reg_bank = 16 |
| 153 | + CtrlType = mk_ctrl(num_fu_inports, num_fu_outports, |
| 154 | + num_tile_inports, num_tile_outports, |
| 155 | + num_registers_per_reg_bank) |
| 156 | + FuInType = mk_bits(clog2(num_inports + 1)) |
| 157 | + |
| 158 | + # Simulating counter output pattern: |
| 159 | + # - pred=1 for valid iterations |
| 160 | + # - pred=0 when loop terminates |
| 161 | + src_in0 = [ |
| 162 | + DataType(0, 1), # counter=0, pred=1 (valid) |
| 163 | + DataType(1, 1), # counter=1, pred=1 (valid) |
| 164 | + DataType(2, 1), # counter=2, pred=1 (valid) |
| 165 | + DataType(3, 0), # counter=3, pred=0 (terminated!) |
| 166 | + ] |
| 167 | + |
| 168 | + src_opt = [ |
| 169 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 170 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 171 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 172 | + CtrlType(OPT_EXTRACT_PREDICATE, fu_in = [FuInType(1), FuInType(0), FuInType(0), FuInType(0)]), |
| 173 | + ] |
| 174 | + |
| 175 | + # Expected: extract predicate as boolean for use with NOT and grant_predicate |
| 176 | + sink_out = [ |
| 177 | + DataType(1, 1), # pred=1 -> payload=1 (continue) |
| 178 | + DataType(1, 1), # pred=1 -> payload=1 (continue) |
| 179 | + DataType(1, 1), # pred=1 -> payload=1 (continue) |
| 180 | + DataType(0, 1), # pred=0 -> payload=0 (terminate!) |
| 181 | + ] |
| 182 | + |
| 183 | + th = TestHarness(ExtractPredicateRTL, DataType, CtrlType, |
| 184 | + num_inports, num_outports, data_mem_size, |
| 185 | + src_in0, src_opt, sink_out) |
| 186 | + run_sim(th) |
0 commit comments