diff --git a/calyx-opt/src/passes/top_down_compile_control.rs b/calyx-opt/src/passes/top_down_compile_control.rs index 453993b16e..7b47c98f21 100644 --- a/calyx-opt/src/passes/top_down_compile_control.rs +++ b/calyx-opt/src/passes/top_down_compile_control.rs @@ -3,7 +3,9 @@ use crate::passes; use crate::traversal::{ Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor, }; -use calyx_ir::{self as ir, GetAttributes, LibrarySignatures, Printer, RRC}; +use calyx_ir::{ + self as ir, BoolAttr, Cell, GetAttributes, LibrarySignatures, Printer, RRC, +}; use calyx_ir::{build_assignments, guard, structure, Id}; use calyx_utils::Error; use calyx_utils::{CalyxResult, OutputFile}; @@ -17,6 +19,7 @@ use std::rc::Rc; const NODE_ID: ir::Attribute = ir::Attribute::Internal(ir::InternalAttr::NODE_ID); +const DUPLICATE_NUM_REG: u64 = 2; /// Computes the exit edges of a given [ir::Control] program. /// @@ -209,10 +212,30 @@ fn compute_unique_ids(con: &mut ir::Control, cur_state: u64) -> u64 { } } -enum Encoding { +#[derive(Clone, Copy)] +enum RegisterEncoding { Binary, OneHot, } +#[derive(Clone, Copy)] +enum RegisterSpread { + // Default option: just a single register + Single, + // Duplicate the register to reduce fanout when querying + // (all FSMs in this vec still have all of the states) + Duplicate, +} + +#[derive(Clone, Copy)] +/// A type that represents how the FSM should be implemented in hardware. +struct FSMRepresentation { + // the representation of a state within a register (one-hot, binary, etc.) + encoding: RegisterEncoding, + // the number of registers representing the dynamic finite state machine + spread: RegisterSpread, + // the index of the last state in the fsm (total # states = last_state + 1) + last_state: u64, +} /// Represents the dyanmic execution schedule of a control program. 
struct Schedule<'b, 'a: 'b> { @@ -329,53 +352,143 @@ impl<'b, 'a> Schedule<'b, 'a> { .unwrap(); }); } - - /// Queries the FSM by building a new slicer and corresponding assignments if - /// the query hasn't yet been made. If this query has been made before, it - /// reuses the old query. Returns a new guard representing the query. - fn build_one_hot_query( + /// Queries the FSM by building a new slicer and corresponding assignments if + /// the query hasn't yet been made. If this query has been made before with one-hot + /// encoding, it reuses the old query, but always returns a new guard representing the query. + fn build_query( builder: &mut ir::Builder, - used_slicers: &mut HashMap>, - fsm: &ir::RRC, - signal_on: &ir::RRC, + used_slicers: &mut HashMap>, + fsm_rep: &FSMRepresentation, + fsm: &RRC, + signal_on: &RRC, state: &u64, fsm_size: &u64, ) -> ir::Guard { - match used_slicers.get(state) { - None => { - // construct slicer for this bit query - structure!( - builder; - let slicer = prim std_bit_slice(*fsm_size, *state, *state, 1); - ); - // build wire from fsm to slicer - let fsm_to_slicer = builder.build_assignment( - slicer.borrow().get("in"), - fsm.borrow().get("out"), - ir::Guard::True, - ); - // add continuous assignments to slicer - builder.component.continuous_assignments.push(fsm_to_slicer); - // create a guard representing when to allow next-state transition - let state_guard = guard!(slicer["out"] == signal_on["out"]); - used_slicers.insert(*state, slicer); + match fsm_rep.encoding { + RegisterEncoding::Binary => { + let state_const = builder.add_constant(*state, *fsm_size); + let state_guard = guard!(fsm["out"] == state_const["out"]); state_guard } - Some(slicer) => { - let state_guard = guard!(slicer["out"] == signal_on["out"]); - state_guard + RegisterEncoding::OneHot => { + match used_slicers.get(state) { + None => { + // construct slicer for this bit query + structure!( + builder; + let slicer = prim std_bit_slice(*fsm_size, *state, *state, 1); 
+ ); + // build wire from fsm to slicer + let fsm_to_slicer = builder.build_assignment( + slicer.borrow().get("in"), + fsm.borrow().get("out"), + ir::Guard::True, + ); + // add continuous assignments to slicer + builder + .component + .continuous_assignments + .push(fsm_to_slicer); + // create a guard representing when to allow next-state transition + let state_guard = + guard!(slicer["out"] == signal_on["out"]); + used_slicers.insert(*state, slicer); + state_guard + } + Some(slicer) => { + let state_guard = + guard!(slicer["out"] == signal_on["out"]); + state_guard + } + } } } } + /// Builds the register(s) and constants needed for a given encoding and spread type. + fn build_fsm_infrastructure( + builder: &mut ir::Builder, + fsm_rep: &FSMRepresentation, + ) -> (Vec>, RRC, u64) { + // get fsm bit width and build constant emitting fsm first state + let (fsm_size, first_state) = match fsm_rep.encoding { + RegisterEncoding::Binary => { + let fsm_size = get_bit_width_from(fsm_rep.last_state + 1); + (fsm_size, builder.add_constant(0, fsm_size)) + } + RegisterEncoding::OneHot => { + let fsm_size = fsm_rep.last_state + 1; + (fsm_size, builder.add_constant(1, fsm_size)) + } + }; + + // for the given number of fsm registers to read from, add a primitive register to the design for each + let mut add_fsm_regs = |prim_name: &str, num_regs: u64| { + (0..num_regs) + .map(|n| { + let fsm_name = if num_regs == 1 { + "fsm".to_string() + } else { + format!("fsm{}", n) + }; + builder.add_primitive( + fsm_name.as_str(), + prim_name, + &[fsm_size], + ) + }) + .collect_vec() + }; + + let fsms = match (fsm_rep.encoding, fsm_rep.spread) { + (RegisterEncoding::Binary, RegisterSpread::Single) => { + add_fsm_regs("std_reg", 1) + } + (RegisterEncoding::OneHot, RegisterSpread::Single) => { + add_fsm_regs("init_one_reg", 1) + } + (RegisterEncoding::Binary, RegisterSpread::Duplicate) => { + add_fsm_regs("std_reg", DUPLICATE_NUM_REG) + } + (RegisterEncoding::OneHot, RegisterSpread::Duplicate) => 
{ + add_fsm_regs("init_one_reg", DUPLICATE_NUM_REG) + } + }; + + (fsms, first_state, fsm_size) + } + + /// Given the state of the FSM, returns the index for the register in `fsms` + /// that should be queried. + /// A query for each state must read from one of the `num_registers` registers. + /// For `r` registers and `n` states, we split into "buckets" as follows: + /// `{0, ... , n/r - 1} -> reg. @ index 0`, + /// `{n/r, ... , 2n/r - 1} -> reg. @ index 1`, + /// ..., + /// `{(r-1)n/r, ... , n - 1} -> reg. @ index r - 1`. + /// Note that dividing each state by the value `n/r` normalizes the state w.r.t. + /// the FSM register from which it should read. We can then take the floor of this value + /// (or, equivalently, use unsigned integer division) to get this register index. + fn register_to_query( + state: u64, + num_states: u64, + fsms: &[RRC], + ) -> usize { + let num_registers: u64 = fsms.len().try_into().unwrap(); + let reg_to_query: usize = + (state * num_registers / (num_states)).try_into().unwrap(); + reg_to_query + } + /// Implement a given [Schedule] and return the name of the [ir::Group] that /// implements it. 
fn realize_schedule( self, dump_fsm: bool, fsm_groups: &mut HashSet, - one_hot_cutoff: u64, + fsm_rep: FSMRepresentation, ) -> RRC { + // confirm all states are reachable self.validate(); // build tdcc group @@ -388,195 +501,174 @@ impl<'b, 'a> Schedule<'b, 'a> { )); } - // calculate fsm size and encoding - let final_state = self.last_state(); - let encoding = if final_state <= one_hot_cutoff { - Encoding::OneHot - } else { - Encoding::Binary - }; - - // build necessary primitives dependent on encoding + // build necessary primitives dependent on encoding and spread let signal_on = self.builder.add_constant(1, 1); - let (fsm, first_state, last_state_opt, fsm_size) = match encoding { - Encoding::Binary => { - let fsm_size = get_bit_width_from( - final_state + 1, /* represent 0..final_state */ - ); - structure!(self.builder; - let fsm = prim std_reg(fsm_size); - let last_state = constant(final_state, fsm_size); - let first_state = constant(0, fsm_size); - ); - (fsm, first_state, Some(last_state), fsm_size) - } - Encoding::OneHot => { - let fsm_size = final_state + 1; /* represent 0..final_state */ + let (fsms, first_state, fsm_size) = + Self::build_fsm_infrastructure(self.builder, &fsm_rep); - let fsm = self.builder.add_primitive( - "fsm", - "init_one_reg", - &[fsm_size], - ); - let first_state = self.builder.add_constant(1, fsm_size); - (fsm, first_state, None, fsm_size) - } - }; + // get first fsm register + let fsm1 = fsms.first().expect("first fsm register does not exist"); // Add last state to JSON info let mut states = self.groups_to_states.iter().cloned().collect_vec(); states.push(FSMStateInfo { - id: final_state, - group: Id::new(format!("{}_END", fsm.borrow().name())), + id: fsm_rep.last_state, // check that this register (fsm.0) is the correct one to use + group: Id::new(format!("{}_END", fsm1.borrow().name())), }); // Keep track of groups to FSM state id information for dumping to json fsm_groups.insert(ProfilingInfo::Fsm(FSMInfo { component: 
self.builder.component.name, - fsm: fsm.borrow().name(), + fsm: fsm1.borrow().name(), group: group.borrow().name(), states, })); - // keep track of used slicers if using one hot encoding - let mut used_slicers = HashMap::new(); + // keep track of used slicers if using one hot encoding. one for each register + let mut used_slicers_vec = + fsms.iter().map(|_| HashMap::new()).collect_vec(); // enable assignments + // the following enable queries; we can decide which register to query for state-dependent assignments + // because we know all registers precisely agree at each cycle group.borrow_mut().assignments.extend( self.enables .into_iter() .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) - .flat_map(|(state, mut assigns)| match encoding { - Encoding::Binary => { - let state_const = - self.builder.add_constant(state, fsm_size); - let state_guard = - guard!(fsm["out"] == state_const["out"]); - assigns.iter_mut().for_each(|asgn| { - asgn.guard.update(|g| g.and(state_guard.clone())) - }); - assigns - } - Encoding::OneHot => { - let state_guard = Self::build_one_hot_query( - self.builder, - &mut used_slicers, - &fsm, - &signal_on, - &state, - &fsm_size, - ); - assigns.iter_mut().for_each(|asgn| { - asgn.guard.update(|g| g.and(state_guard.clone())) - }); - assigns - } + .flat_map(|(state, mut assigns)| { + // find the register from which to query; try to split evenly among registers + let (fsm, used_slicers) = { + let reg_to_query = Self::register_to_query(state, fsm_rep.last_state, &fsms); + ( + fsms.get(reg_to_query).expect("the register at this index does not exist"), + used_slicers_vec.get_mut(reg_to_query).expect( + "the used slicer map at this index does not exist", + ), + )}; + // for every assignment dependent on current fsm state, `&` new guard with existing guard + let state_guard = Self::build_query( + self.builder, + used_slicers, + &fsm_rep, + fsm, + &signal_on, + &state, + &fsm_size, + ); + assigns.iter_mut().for_each(|asgn| { + asgn.guard.update(|g| 
g.and(state_guard.clone())) + }); + assigns }), ); // transition assignments + // the following updates are meant to ensure agreement between the two + // fsm registers; hence, all registers must be updated if `duplicate` is chosen group.borrow_mut().assignments.extend( self.transitions.into_iter().flat_map(|(s, e, guard)| { - let (end_const, trans_guard) = match encoding { - Encoding::Binary => { - structure!(self.builder; - let end_const = constant(e, fsm_size); - let start_const = constant(s, fsm_size); - ); - let trans_guard = - guard!((fsm["out"] == start_const["out"]) & guard); - - (end_const, trans_guard) - } - Encoding::OneHot => { - let end_constant_value = u64::pow( - 2, - e.try_into().expect("failed to convert to u32"), - ); - - let trans_guard = Self::build_one_hot_query( - self.builder, - &mut used_slicers, - &fsm, - &signal_on, - &s, - &fsm_size, - ); - let end_const = self - .builder - .add_constant(end_constant_value, fsm_size); - - (end_const, trans_guard.and(guard)) - } - }; - let ec_borrow = end_const.borrow(); - vec![ - self.builder.build_assignment( - fsm.borrow().get("in"), - ec_borrow.get("out"), - trans_guard.clone(), - ), - self.builder.build_assignment( - fsm.borrow().get("write_en"), - signal_on.borrow().get("out"), - trans_guard, + // get a transition guard for the first fsm register, and apply it to every fsm register + let state_guard = Self::build_query( + self.builder, + used_slicers_vec.get_mut(0).expect( + "the used slicer map at this index 0 does not exist", ), - ] + &fsm_rep, + fsms.first().expect("register 0 does not exist"), + &signal_on, + &s, + &fsm_size, + ); + + // add transitions for every fsm register to ensure consistency between each + fsms.iter() + .flat_map(|fsm| { + let trans_guard = + state_guard.clone().and(guard.clone()); + let end_const = match fsm_rep.encoding { + RegisterEncoding::Binary => { + self.builder.add_constant(e, fsm_size) + } + RegisterEncoding::OneHot => { + self.builder.add_constant( + u64::pow( + 
2, + e.try_into() + .expect("failed to convert to u32"), + ), + fsm_size, + ) + } + }; + let ec_borrow = end_const.borrow(); + vec![ + self.builder.build_assignment( + fsm.borrow().get("in"), + ec_borrow.get("out"), + trans_guard.clone(), + ), + self.builder.build_assignment( + fsm.borrow().get("write_en"), + signal_on.borrow().get("out"), + trans_guard, + ), + ] + }) + .collect_vec() }), ); // done condition for group - let reset_fsm = match last_state_opt { - // binary branch; only binary needs last state constant - Some(last_state) => { - let last_guard = guard!(fsm["out"] == last_state["out"]); - let done_assign = self.builder.build_assignment( - group.borrow().get("done"), - signal_on.borrow().get("out"), - last_guard.clone(), - ); - group.borrow_mut().assignments.push(done_assign); + // arbitrarily look at first fsm register, since all are identical + let first_fsm_last_guard = Self::build_query( + self.builder, + used_slicers_vec + .get_mut(0) + .expect("the used slicer map at this index does not exist"), + &fsm_rep, + fsm1, + &signal_on, + &fsm_rep.last_state, + &fsm_size, + ); - // Cleanup: Add a transition from last state to the first state. - let reset_fsm = build_assignments!(self.builder; - fsm["in"] = last_guard ? first_state["out"]; - fsm["write_en"] = last_guard ? 
signal_on["out"]; - ); + let done_assign = self.builder.build_assignment( + group.borrow().get("done"), + signal_on.borrow().get("out"), + first_fsm_last_guard.clone(), + ); - reset_fsm.to_vec() - } + group.borrow_mut().assignments.push(done_assign); - // ohe branch does not need last state constant - None => { - let last_guard = Self::build_one_hot_query( + // Cleanup: Add a transition from last state to the first state for each register + let reset_fsms = fsms + .iter() + .enumerate() + .flat_map(|(i, fsm)| { + let fsm_last_guard = Self::build_query( self.builder, - &mut used_slicers, - &fsm, + used_slicers_vec.get_mut(i).expect( + "the used slicer map at this index does not exist", + ), + &fsm_rep, + fsm, &signal_on, - &final_state, + &fsm_rep.last_state, &fsm_size, ); - let done_assign = self.builder.build_assignment( - group.borrow().get("done"), - signal_on.borrow().get("out"), - last_guard.clone(), - ); - group.borrow_mut().assignments.push(done_assign); - // Cleanup: Add a transition from last state to the first state. let reset_fsm = build_assignments!(self.builder; - fsm["in"] = last_guard ? first_state["out"]; - fsm["write_en"] = last_guard ? signal_on["out"]; + fsm["in"] = fsm_last_guard ? first_state["out"]; + fsm["write_en"] = fsm_last_guard ? 
signal_on["out"]; ); - reset_fsm.to_vec() - } - }; + }) + .collect_vec(); - // extend with conditions to set fsm to initial state + // extend with conditions to set all fsms to initial state self.builder .component .continuous_assignments - .extend(reset_fsm); + .extend(reset_fsms); group } @@ -994,9 +1086,40 @@ pub struct TopDownCompileControl { early_transitions: bool, /// Bookkeeping for FSM ids for groups across all FSMs in the program fsm_groups: HashSet, - /// How many states the dynamic FSM must have before we pick binary encoding over - /// one-hot + /// How many states the dynamic FSM must have before picking binary over one-hot one_hot_cutoff: u64, + /// Number of states the dynamic FSM must have before picking duplicate over single register + duplicate_cutoff: u64, +} + +impl TopDownCompileControl { + /// Given a dynamic schedule and attributes, selects a representation for + /// the finite state machine in hardware. + fn get_representation( + &self, + sch: &Schedule, + attrs: &ir::Attributes, + ) -> FSMRepresentation { + let last_state = sch.last_state(); + FSMRepresentation { + encoding: { + match ( + attrs.has(BoolAttr::OneHot), + last_state <= self.one_hot_cutoff, + ) { + (true, _) | (false, true) => RegisterEncoding::OneHot, + (false, false) => RegisterEncoding::Binary, + } + }, + spread: { + match (last_state + 1) <= self.duplicate_cutoff { + true => RegisterSpread::Single, + false => RegisterSpread::Duplicate, + } + }, + last_state, + } + } } impl ConstructVisitor for TopDownCompileControl { @@ -1014,6 +1137,9 @@ impl ConstructVisitor for TopDownCompileControl { one_hot_cutoff: opts[&"one-hot-cutoff"] .pos_num() .expect("requires non-negative OHE cutoff parameter"), + duplicate_cutoff: opts[&"duplicate-cutoff"] + .pos_num() + .expect("requires non-negative duplicate cutoff parameter"), }) } @@ -1053,10 +1179,16 @@ impl Named for TopDownCompileControl { ), PassOpt::new( "one-hot-cutoff", - "The threshold at and below which a one-hot encoding is 
used for dynamic group scheduling", + "Threshold at and below which a one-hot encoding is used for dynamic group scheduling", ParseVal::Num(0), PassOpt::parse_num, ), + PassOpt::new( + "duplicate-cutoff", + "Threshold above which the dynamic fsm register is replicated into a second, identical register", + ParseVal::Num(i64::MAX), + PassOpt::parse_num, + ), ] } } @@ -1116,12 +1248,11 @@ impl Visitor for TopDownCompileControl { let mut builder = ir::Builder::new(comp, sigs); let mut sch = Schedule::from(&mut builder); sch.calculate_states_seq(s, self.early_transitions)?; + let fsm_impl = self.get_representation(&sch, &s.attributes); + // Compile schedule and return the group. - let seq_group = sch.realize_schedule( - self.dump_fsm, - &mut self.fsm_groups, - self.one_hot_cutoff, - ); + let seq_group = + sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups, fsm_impl); // Add NODE_ID to compiled group. let mut en = ir::Control::enable(seq_group); @@ -1147,11 +1278,9 @@ impl Visitor for TopDownCompileControl { // Compile schedule and return the group. sch.calculate_states_if(i, self.early_transitions)?; - let if_group = sch.realize_schedule( - self.dump_fsm, - &mut self.fsm_groups, - self.one_hot_cutoff, - ); + let fsm_impl = self.get_representation(&sch, &i.attributes); + let if_group = + sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups, fsm_impl); // Add NODE_ID to compiled group. let mut en = ir::Control::enable(if_group); @@ -1175,13 +1304,11 @@ impl Visitor for TopDownCompileControl { let mut builder = ir::Builder::new(comp, sigs); let mut sch = Schedule::from(&mut builder); sch.calculate_states_while(w, self.early_transitions)?; + let fsm_impl = self.get_representation(&sch, &w.attributes); // Compile schedule and return the group. 
- let if_group = sch.realize_schedule( - self.dump_fsm, - &mut self.fsm_groups, - self.one_hot_cutoff, - ); + let if_group = + sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups, fsm_impl); // Add NODE_ID to compiled group. let mut en = ir::Control::enable(if_group); @@ -1229,10 +1356,11 @@ impl Visitor for TopDownCompileControl { _ => { let mut sch = Schedule::from(&mut builder); sch.calculate_states(con, self.early_transitions)?; + let fsm_impl = self.get_representation(&sch, &s.attributes); sch.realize_schedule( self.dump_fsm, &mut self.fsm_groups, - self.one_hot_cutoff, + fsm_impl, ) } }; @@ -1299,16 +1427,17 @@ impl Visitor for TopDownCompileControl { _comps: &[ir::Component], ) -> VisResult { let control = Rc::clone(&comp.control); - // IRPrinter::write_control(&control.borrow(), 0, &mut std::io::stderr()); + let attrs = comp.attributes.clone(); + let mut builder = ir::Builder::new(comp, sigs); let mut sch = Schedule::from(&mut builder); + // Add assignments for the final states sch.calculate_states(&control.borrow(), self.early_transitions)?; - let comp_group = sch.realize_schedule( - self.dump_fsm, - &mut self.fsm_groups, - self.one_hot_cutoff, - ); + let fsm_impl = self.get_representation(&sch, &attrs); + let comp_group = + sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups, fsm_impl); + if let Some(json_out_file) = &self.dump_fsm_json { let _ = serde_json::to_writer_pretty( json_out_file.get_write(), diff --git a/calyx-opt/src/traversal/construct.rs b/calyx-opt/src/traversal/construct.rs index e9b64b5902..65042598a2 100644 --- a/calyx-opt/src/traversal/construct.rs +++ b/calyx-opt/src/traversal/construct.rs @@ -12,6 +12,8 @@ pub enum ParseVal { Bool(bool), /// A number option. Num(i64), + /// A string option. + String(String), /// A list of values. 
List(Vec), /// An output stream (stdout, stderr, file name) @@ -33,6 +35,13 @@ impl ParseVal { *n } + pub fn string(&self) -> String { + let ParseVal::String(s) = self else { + panic!("Expected String, got {self}"); + }; + s.clone() + } + pub fn pos_num(&self) -> Option { let n = self.num(); if n < 0 { @@ -86,6 +95,7 @@ impl std::fmt::Display for ParseVal { match self { ParseVal::Bool(b) => write!(f, "{b}"), ParseVal::Num(n) => write!(f, "{n}"), + ParseVal::String(s) => write!(f, "{s}"), ParseVal::List(l) => { write!(f, "[")?; for (i, e) in l.iter().enumerate() { @@ -166,6 +176,11 @@ impl PassOpt { s.parse::().ok().map(ParseVal::Num) } + /// Parse a String from a string. + pub fn parse_string(s: &str) -> Option { + Some(ParseVal::String(s.to_string())) + } + /// Parse a list of numbers from a string. pub fn parse_num_list(s: &str) -> Option { Self::parse_list(s, Self::parse_num) diff --git a/runt.toml b/runt.toml index 8a4be975b6..3e7f19d259 100644 --- a/runt.toml +++ b/runt.toml @@ -245,6 +245,27 @@ fud exec --from calyx --to jq \ """ timeout = 120 +[[tests]] +name = "correctness dynamic fsm register duplication" +paths = [ + "tests/correctness/*.futil", + "tests/correctness/ref-cells/*.futil", + "tests/correctness/sync/*.futil", + "tests/correctness/static-interface/*.futil", +] +cmd = """ +fud exec --from calyx --to jq \ + --through verilog \ + --through dat \ + -s calyx.exec './target/debug/calyx' \ + -s calyx.flags '-x tdcc:duplicate-cutoff=0 -d static-promotion' \ + -s verilog.cycle_limit 500 \ + -s verilog.data {}.data \ + -s jq.expr ".memories" \ + {} -q +""" +timeout = 120 + [[tests]] name = "correctness static timing" paths = [