Skip to content

Commit

Permalink
comb-prop: Disable rewrites when wire output is used (#1656)
Browse files Browse the repository at this point in the history
  • Loading branch information
rachitnigam authored Aug 11, 2023
1 parent 22742f8 commit 4970ba5
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## Unreleased
- Don't require `@clk` and `@reset` ports in `comb` components
- `inline` pass supports inlining `ref` cells
- `comb-prop`: disable rewrite from `wire.in = port` when the output of a wire is read.


## 0.4.0
Expand Down
20 changes: 20 additions & 0 deletions calyx-ir/src/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,26 @@ pub struct Assignment<T> {
}

impl<T> Assignment<T> {
/// Build a new unguarded assignment
pub fn new(dst: RRC<Port>, src: RRC<Port>) -> Self {
assert!(
dst.borrow().direction == Direction::Input,
"{} is not in input port",
dst.borrow().canonical()
);
assert!(
src.borrow().direction == Direction::Output,
"{} is not in output port",
src.borrow().canonical()
);
Self {
dst,
src,
guard: Box::new(Guard::True),
attributes: Attributes::default(),
}
}

/// Apply function `f` to each port contained within the assignment and
/// replace the port with the generated value if not None.
pub fn for_each_port<F>(&mut self, mut f: F)
Expand Down
228 changes: 167 additions & 61 deletions calyx-opt/src/passes/comb_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,87 @@ struct WireRewriter {
}

impl WireRewriter {
/// Insert into rewrite map. If `v` is in current `rewrites`, then insert `k` -> `rewrites[v]`.
/// Panics if there is already a mapping for `k`.
pub fn insert(
// If the destination is a wire, then we have something like:
// ```
// wire.in = c.out;
// ```
// Which means all instances of `wire.out` can be replaced with `c.out` because the
// wire is being used to forward values from `c.out`.
pub fn insert_src_rewrite(
&mut self,
wire: RRC<ir::Cell>,
src: RRC<ir::Port>,
) {
let wire_out = wire.borrow().get("out");
log::debug!(
"src rewrite: {} -> {}",
wire_out.borrow().canonical(),
src.borrow().canonical(),
);
let old = self.insert(wire_out, Rc::clone(&src));
assert!(
old.is_none(),
"Attempting to add multiple sources to a wire"
);
}

// If the source is a wire, we have something like:
// ```
// c.in = wire.out;
// ```
// Which means all instances of `wire.in` can be replaced with `c.in` because the wire
// is being used to unconditionally forward values.
pub fn insert_dst_rewrite(
&mut self,
wire: RRC<ir::Cell>,
dst: RRC<ir::Port>,
) {
let wire_in = wire.borrow().get("in");
log::debug!(
"dst rewrite: {} -> {}",
wire_in.borrow().canonical(),
dst.borrow().canonical(),
);
let old_v = self.insert(Rc::clone(&wire_in), dst);

// If the insertion process found an old key, we have something like:
// ```
// x.in = wire.out;
// y.in = wire.out;
// ```
// This means that `wire` is being used to forward values to many components and a
// simple inlining will not work.
if old_v.is_some() {
self.remove(wire_in);
}

// No forwading generated because the wire is used in dst position
}

/// Insert into rewrite map. If `v` is in current `rewrites`, then insert `k` -> `rewrites[v]`
/// and returns the previous rewrite if any.
fn insert(
&mut self,
from: RRC<ir::Port>,
to: RRC<ir::Port>,
) -> Option<RRC<ir::Port>> {
let from_idx = from.borrow().canonical();
self.rewrites.insert(from_idx, to)
let old = self.rewrites.insert(from_idx, to);
if log::log_enabled!(log::Level::Debug) {
if let Some(ref old) = old {
log::debug!(
"Previous rewrite: {} -> {}",
from.borrow().canonical(),
old.borrow().canonical()
);
}
}
old
}

// Removes the mapping associated with the key.
pub fn remove(&mut self, from: RRC<ir::Port>) {
log::debug!("Removing rewrite for `{}'", from.borrow().canonical());
let from_idx = from.borrow().canonical();
self.rewrites.remove(&from_idx);
}
Expand Down Expand Up @@ -123,7 +191,7 @@ pub struct CombProp {
/// Disable automatic removal of some dead assignments needed for correctness and instead mark
/// them with @dead.
/// NOTE: if this is enabled, the pass will not remove obviously conflicting assignments.
do_not_eliminate: bool,
no_eliminate: bool,
}

impl ConstructVisitor for CombProp {
Expand All @@ -133,7 +201,7 @@ impl ConstructVisitor for CombProp {
{
let opts = Self::get_opts(ctx);
Ok(CombProp {
do_not_eliminate: opts[0],
no_eliminate: opts[0],
})
}

Expand All @@ -160,26 +228,82 @@ impl Named for CombProp {
}

impl CombProp {
/// Predicate for removing an assignment
#[inline]
fn remove_predicate<T>(
rewritten: &[RRC<ir::Port>],
assign: &ir::Assignment<T>,
) -> bool
where
T: Clone + Eq + ToString,
{
let out = rewritten.iter().any(|v| Rc::ptr_eq(v, &assign.dst));
if log::log_enabled!(log::Level::Debug) && out {
log::debug!("Removing: {}", ir::Printer::assignment_to_str(assign));
}
out
}

/// Mark assignments for removal
fn remove_rewritten(
&self,
rewritten: Vec<&RRC<ir::Port>>,
rewritten: &[RRC<ir::Port>],
comp: &mut ir::Component,
) {
log::debug!(
"Rewritten: {}",
rewritten
.iter()
.map(|p| format!("{}", p.borrow().canonical()))
.collect::<Vec<_>>()
.join(", ")
);
// Remove writes to all the ports that show up in write position
if self.do_not_eliminate {
if self.no_eliminate {
// If elimination is disabled, mark the assignments with the @dead attribute.
for assign in &mut comp.continuous_assignments {
if rewritten.iter().any(|v| Rc::ptr_eq(v, &assign.dst)) {
if Self::remove_predicate(rewritten, assign) {
assign.attributes.insert(ir::InternalAttr::DEAD, 1)
}
}
} else {
comp.continuous_assignments.retain_mut(|assign| {
!rewritten.iter().any(|v| Rc::ptr_eq(v, &assign.dst))
!Self::remove_predicate(rewritten, assign)
});
}
}

fn parent_is_wire(parent: &ir::PortParent) -> bool {
match parent {
ir::PortParent::Cell(cell_wref) => {
let cr = cell_wref.upgrade();
let cell = cr.borrow();
cell.is_primitive(Some("std_wire"))
}
ir::PortParent::Group(_) => false,
ir::PortParent::StaticGroup(_) => false,
}
}

fn disable_rewrite<T>(
assign: &mut ir::Assignment<T>,
rewrites: &mut WireRewriter,
) {
if assign.guard.is_true() {
return;
}
assign.for_each_port(|pr| {
let p = pr.borrow();
if p.direction == ir::Direction::Output
&& Self::parent_is_wire(&p.parent)
{
let cell = p.cell_parent();
rewrites.remove(cell.borrow().get("in"));
}
// Never change the port
None
});
}
}

impl Visitor for CombProp {
Expand All @@ -191,76 +315,58 @@ impl Visitor for CombProp {
) -> VisResult {
let mut rewrites = WireRewriter::default();

let parent_is_wire = |parent: &ir::PortParent| -> bool {
match parent {
ir::PortParent::Cell(cell_wref) => {
let cr = cell_wref.upgrade();
let cell = cr.borrow();
cell.is_primitive(Some("std_wire"))
}
ir::PortParent::Group(_) => false,
ir::PortParent::StaticGroup(_) => false,
}
};

for assign in &mut comp.continuous_assignments {
// Skip conditional continuous assignments
// Cannot add rewrites for conditional statements
if !assign.guard.is_true() {
continue;
}
// If the destination is a wire, then we have something like:
// ```
// wire.in = c.out;
// ```
// Which means all instances of `wire.out` can be replaced with `c.out` because the
// wire is being used to forward values from `c.out`.

let dst = assign.dst.borrow();
if parent_is_wire(&dst.parent) {
rewrites.insert(
dst.cell_parent().borrow().get("out"),
if Self::parent_is_wire(&dst.parent) {
rewrites.insert_src_rewrite(
dst.cell_parent(),
Rc::clone(&assign.src),
);
}

// If the source is a wire, we have something like:
// ```
// c.in = wire.out;
// ```
// Which means all instances of `wire.in` can be replaced with `c.in` because the wire
// is being used to unconditionally forward values.
let src = assign.src.borrow();
if parent_is_wire(&src.parent) {
let port = src.cell_parent().borrow().get("in");
let old_v =
rewrites.insert(Rc::clone(&port), Rc::clone(&assign.dst));

// If the insertion process found an old key, we have something like:
// ```
// x.in = wire.out;
// y.in = wire.out;
// ```
// This means that `wire` is being used to forward values to many components and a
// simple inlining will not work.
if old_v.is_some() {
rewrites.remove(port);
}
if Self::parent_is_wire(&src.parent) {
rewrites.insert_dst_rewrite(
src.cell_parent(),
Rc::clone(&assign.dst),
);
}
}

// Disable all rewrites:
// If the statement uses a wire output (w.out) as a source, we
// cannot rewrite the wire's input (w.in) uses
comp.for_each_assignment(|assign| {
Self::disable_rewrite(assign, &mut rewrites)
});
comp.for_each_static_assignment(|assign| {
Self::disable_rewrite(assign, &mut rewrites)
});

// Rewrite assignments
// Make the set of rewrites consistent and transform into map
let rewrites: ir::rewriter::PortRewriteMap = rewrites.into();
let rewritten = rewrites.values().collect_vec();
self.remove_rewritten(rewritten, comp);
let rewritten = rewrites.values().cloned().collect_vec();
self.remove_rewritten(&rewritten, comp);

comp.for_each_assignment(|assign| {
assign.for_each_port(|port| {
rewrites.get(&port.borrow().canonical()).cloned()
})
if !assign.attributes.has(ir::InternalAttr::DEAD) {
assign.for_each_port(|port| {
rewrites.get(&port.borrow().canonical()).cloned()
})
}
});
comp.for_each_static_assignment(|assign| {
assign.for_each_port(|port| {
rewrites.get(&port.borrow().canonical()).cloned()
})
if !assign.attributes.has(ir::InternalAttr::DEAD) {
assign.for_each_port(|port| {
rewrites.get(&port.borrow().canonical()).cloned()
})
}
});

let cell_rewrites = HashMap::new();
Expand Down
18 changes: 18 additions & 0 deletions tests/passes/comb-prop/wire-dst-read.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import "primitives/compile.futil";
component main<"toplevel"=1>(@go go: 1, @clk clk: 1, @reset reset: 1) -> (out: 32, @done done: 1) {
cells {
opt = std_wire(32);
r = std_reg(32);
r0 = std_reg(1);
r1 = std_reg(1);
r2 = std_reg(1);
}
wires {
r.in = opt.out;
r.write_en = 1'd1;
opt.in = r0.out ? 32'd10;
opt.in = r1.out ? 32'd20;
out = r2.out ? opt.out;
}
control {}
}
22 changes: 22 additions & 0 deletions tests/passes/comb-prop/wire-dst-read.futil
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// -p well-formed -p comb-prop
import "primitives/compile.futil";
component main<"toplevel"=1>() -> (out: 32) {
cells {
opt = std_wire(32);
r = std_reg(32);
// Stable conditions
r0 = std_reg(1);
r1 = std_reg(1);
r2 = std_reg(1);
}
wires {
r.in = opt.out;
r.write_en = 1'd1;

opt.in = r0.out ? 32'd10;
opt.in = r1.out ? 32'd20;

out = r2.out ? opt.out;
}
control {}
}

0 comments on commit 4970ba5

Please sign in to comment.