From 894e221762118e3babeb150752c496c5bb9d939c Mon Sep 17 00:00:00 2001 From: calebmkim <55243755+calebmkim@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:46:58 -0400 Subject: [PATCH] Fix test case for duplicate FSM optimization (#2204) @parthsarkar17 The runt test for duplicate FSMs was passing in the incorrect parameter. What probably happened was that it was using an old pass input parameter name that was updated but then the runt command was not updated. This is one of the reasons why we should have at least one snapshot test for the FSM optimizations, so things like this are easier to catch (not a big deal at all, but I think it's good practice from now on). --------- Co-authored-by: Parth Sarkar --- .../src/passes/top_down_compile_control.rs | 131 ++++++++++-------- runt.toml | 2 +- tests/passes/tdcc-duplicate/seq.expect | 53 +++++++ tests/passes/tdcc-duplicate/seq.futil | 36 +++++ 4 files changed, 161 insertions(+), 61 deletions(-) create mode 100644 tests/passes/tdcc-duplicate/seq.expect create mode 100644 tests/passes/tdcc-duplicate/seq.futil diff --git a/calyx-opt/src/passes/top_down_compile_control.rs b/calyx-opt/src/passes/top_down_compile_control.rs index 7b47c98f21..398f473332 100644 --- a/calyx-opt/src/passes/top_down_compile_control.rs +++ b/calyx-opt/src/passes/top_down_compile_control.rs @@ -212,6 +212,36 @@ fn compute_unique_ids(con: &mut ir::Control, cur_state: u64) -> u64 { } } +/// Given the state of the FSM, returns the index for the register in `fsms`` +/// that should be queried. +/// A query for each state must read from one of the `num_registers` registers. +/// For `r` registers and `n` states, we split into "buckets" as follows: +/// `{0, ... , n/r - 1} -> reg. @ index 0`, +/// `{n/r, ... , 2n/r - 1} -> reg. @ index 1`, +/// ..., +/// `{(r-1)n/r, ... , n - 1} -> reg. @ index n - 1`. +/// Note that dividing each state by the value `n/r`normalizes the state w.r.t. +/// the FSM register from which it should read. We can then take the floor of this value +/// (or, equivalently, use unsigned integer division) to get this register index. +fn register_to_query( + state: u64, + num_states: u64, + num_registers: u64, + distribute: bool, +) -> usize { + match distribute { + true => { + // num_states+1 is needed to prevent error (the done condition needs + // to check past the number of states, i.e., will check fsm == 3 when + // num_states == 3). + (state * num_registers / (num_states + 1)) + .try_into() + .unwrap() + } + false => 0, + } +} + #[derive(Clone, Copy)] enum RegisterEncoding { Binary, @@ -352,18 +382,36 @@ impl<'b, 'a> Schedule<'b, 'a> { .unwrap(); }); } - // Queries the FSM by building a new slicer and corresponding assignments if + + /// First chooses which register to query from (only relevant in the duplication case.) + /// Then queries the FSM by building a new slicer and corresponding assignments if /// the query hasn't yet been made. If this query has been made before with one-hot /// encoding, it reuses the old query, but always returns a new guard representing the query. - fn build_query( + fn query_state( builder: &mut ir::Builder, - used_slicers: &mut HashMap>, + used_slicers_vec: &mut [HashMap>], fsm_rep: &FSMRepresentation, - fsm: &RRC, - signal_on: &RRC, + hardware: (&[RRC], &RRC), state: &u64, fsm_size: &u64, + distribute: bool, ) -> ir::Guard { + let (fsms, signal_on) = hardware; + let (fsm, used_slicers) = { + let reg_to_query = register_to_query( + *state, + fsm_rep.last_state, + fsms.len().try_into().unwrap(), + distribute, + ); + ( + fsms.get(reg_to_query) + .expect("the register at this index does not exist"), + used_slicers_vec + .get_mut(reg_to_query) + .expect("the used slicer map at this index does not exist"), + ) + }; match fsm_rep.encoding { RegisterEncoding::Binary => { let state_const = builder.add_constant(*state, *fsm_size); @@ -458,28 +506,6 @@ impl<'b, 'a> Schedule<'b, 'a> { (fsms, first_state, fsm_size) } - /// Given the state of the FSM, returns the index for the register in `fsms`` - /// that should be queried. - /// A query for each state must read from one of the `num_registers` registers. - /// For `r` registers and `n` states, we split into "buckets" as follows: - /// `{0, ... , n/r - 1} -> reg. @ index 0`, - /// `{n/r, ... , 2n/r - 1} -> reg. @ index 1`, - /// ..., - /// `{(r-1)n/r, ... , n - 1} -> reg. @ index n - 1`. - /// Note that dividing each state by the value `n/r`normalizes the state w.r.t. - /// the FSM register from which it should read. We can then take the floor of this value - /// (or, equivalently, use unsigned integer division) to get this register index. - fn register_to_query( - state: u64, - num_states: u64, - fsms: &[RRC], - ) -> usize { - let num_registers: u64 = fsms.len().try_into().unwrap(); - let reg_to_query: usize = - (state * num_registers / (num_states)).try_into().unwrap(); - reg_to_query - } - /// Implement a given [Schedule] and return the name of the [ir::Group] that /// implements it. fn realize_schedule( @@ -536,24 +562,15 @@ impl<'b, 'a> Schedule<'b, 'a> { .into_iter() .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) .flat_map(|(state, mut assigns)| { - // find the register from which to query; try to split evenly among registers - let (fsm, used_slicers) = { - let reg_to_query = Self::register_to_query(state, fsm_rep.last_state, &fsms); - ( - fsms.get(reg_to_query).expect("the register at this index does not exist"), - used_slicers_vec.get_mut(reg_to_query).expect( - "the used slicer map at this index does not exist", - ), - )}; // for every assignment dependent on current fsm state, `&` new guard with existing guard - let state_guard = Self::build_query( + let state_guard = Self::query_state( self.builder, - used_slicers, + &mut used_slicers_vec, &fsm_rep, - fsm, - &signal_on, + (&fsms, &signal_on), &state, &fsm_size, + true, // by default attempt to distribute across regs if >=2 exist ); assigns.iter_mut().for_each(|asgn| { asgn.guard.update(|g| g.and(state_guard.clone())) @@ -568,16 +585,14 @@ impl<'b, 'a> Schedule<'b, 'a> { group.borrow_mut().assignments.extend( self.transitions.into_iter().flat_map(|(s, e, guard)| { // get a transition guard for the first fsm register, and apply it to every fsm register - let state_guard = Self::build_query( + let state_guard = Self::query_state( self.builder, - used_slicers_vec.get_mut(0).expect( - "the used slicer map at this index 0 does not exist", - ), + &mut used_slicers_vec, &fsm_rep, - fsms.first().expect("register 0 does not exist"), - &signal_on, + (&fsms, &signal_on), &s, &fsm_size, + false, // by default do not distribute transition queries across regs; choose first ); // add transitions for every fsm register to ensure consistency between each @@ -620,16 +635,14 @@ impl<'b, 'a> Schedule<'b, 'a> { // done condition for group // arbitrarily look at first fsm register, since all are identical - let first_fsm_last_guard = Self::build_query( + let first_fsm_last_guard = Self::query_state( self.builder, - used_slicers_vec - .get_mut(0) - .expect("the used slicer map at this index does not exist"), + &mut used_slicers_vec, &fsm_rep, - fsm1, - &signal_on, + (&fsms, &signal_on), &fsm_rep.last_state, &fsm_size, + false, ); let done_assign = self.builder.build_assignment( @@ -643,18 +656,16 @@ impl<'b, 'a> Schedule<'b, 'a> { // Cleanup: Add a transition from last state to the first state for each register let reset_fsms = fsms .iter() - .enumerate() - .flat_map(|(i, fsm)| { - let fsm_last_guard = Self::build_query( + .flat_map(|fsm| { + // by default, query first register + let fsm_last_guard = Self::query_state( self.builder, - used_slicers_vec.get_mut(i).expect( - "the used slicer map at this index does not exist", - ), + &mut used_slicers_vec, &fsm_rep, - fsm, - &signal_on, + (&fsms, &signal_on), &fsm_rep.last_state, &fsm_size, + false, ); let reset_fsm = build_assignments!(self.builder; fsm["in"] = fsm_last_guard ? first_state["out"]; diff --git a/runt.toml b/runt.toml index 3e7f19d259..61e3cb1906 100644 --- a/runt.toml +++ b/runt.toml @@ -258,7 +258,7 @@ fud exec --from calyx --to jq \ --through verilog \ --through dat \ -s calyx.exec './target/debug/calyx' \ - -s calyx.flags '-x tdcc:spread=duplicate -d static-promotion' \ + -s calyx.flags '-x tdcc:duplicate-cutoff=0 -d static-promotion' \ -s verilog.cycle_limit 500 \ -s verilog.data {}.data \ -s jq.expr ".memories" \ diff --git a/tests/passes/tdcc-duplicate/seq.expect b/tests/passes/tdcc-duplicate/seq.expect new file mode 100644 index 0000000000..6876acdf57 --- /dev/null +++ b/tests/passes/tdcc-duplicate/seq.expect @@ -0,0 +1,53 @@ +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + @generated fsm0 = std_reg(2); + @generated fsm1 = std_reg(2); + } + wires { + group A { + a.in = 2'd0; + a.write_en = 1'd1; + A[done] = a.done; + } + group B { + b.in = 2'd1; + b.write_en = 1'd1; + B[done] = b.done; + } + group C { + c.in = 2'd2; + c.write_en = 1'd1; + C[done] = c.done; + } + group tdcc { + A[go] = !A[done] & fsm0.out == 2'd0 ? 1'd1; + B[go] = !B[done] & fsm0.out == 2'd1 ? 1'd1; + C[go] = !C[done] & fsm1.out == 2'd2 ? 1'd1; + fsm0.in = fsm0.out == 2'd0 & A[done] ? 2'd1; + fsm0.write_en = fsm0.out == 2'd0 & A[done] ? 1'd1; + fsm1.in = fsm0.out == 2'd0 & A[done] ? 2'd1; + fsm1.write_en = fsm0.out == 2'd0 & A[done] ? 1'd1; + fsm0.in = fsm0.out == 2'd1 & B[done] ? 2'd2; + fsm0.write_en = fsm0.out == 2'd1 & B[done] ? 1'd1; + fsm1.in = fsm0.out == 2'd1 & B[done] ? 2'd2; + fsm1.write_en = fsm0.out == 2'd1 & B[done] ? 1'd1; + fsm0.in = fsm0.out == 2'd2 & C[done] ? 2'd3; + fsm0.write_en = fsm0.out == 2'd2 & C[done] ? 1'd1; + fsm1.in = fsm0.out == 2'd2 & C[done] ? 2'd3; + fsm1.write_en = fsm0.out == 2'd2 & C[done] ? 1'd1; + tdcc[done] = fsm0.out == 2'd3 ? 1'd1; + } + fsm0.in = fsm0.out == 2'd3 ? 2'd0; + fsm0.write_en = fsm0.out == 2'd3 ? 1'd1; + fsm1.in = fsm0.out == 2'd3 ? 2'd0; + fsm1.write_en = fsm0.out == 2'd3 ? 1'd1; + } + control { + tdcc; + } +} diff --git a/tests/passes/tdcc-duplicate/seq.futil b/tests/passes/tdcc-duplicate/seq.futil new file mode 100644 index 0000000000..1979af851a --- /dev/null +++ b/tests/passes/tdcc-duplicate/seq.futil @@ -0,0 +1,36 @@ +// -x tdcc:duplicate-cutoff=0 -p tdcc + +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; + +component main() -> () { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + } + + wires { + group A { + a.in = 2'd0; + a.write_en = 1'b1; + A[done] = a.done; + } + + group B { + b.in = 2'd1; + b.write_en = 1'b1; + B[done] = b.done; + } + + group C { + c.in = 2'd2; + c.write_en = 1'b1; + C[done] = c.done; + } + } + + control { + seq { A; B; C; } + } +}