Skip to content

Commit

Permalink
PIFO in Calyx (#1625)
Browse files Browse the repository at this point in the history
  • Loading branch information
anshumanmohan authored Aug 15, 2023
1 parent 4970ba5 commit 0bd028c
Show file tree
Hide file tree
Showing 7 changed files with 547 additions and 195 deletions.
25 changes: 15 additions & 10 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,6 @@ def seq_mem_d1(
name, ast.Stdlib.seq_mem_d1(bitwidth, len, idx_size), is_external, is_ref
)

def is_seq_mem_d1(self, cell: CellBuilder) -> bool:
"""Check if the cell is a SeqMemD1 cell."""
return (
isinstance(cell._cell.comp, ast.CompInst)
and cell._cell.comp.name == "seq_mem_d1"
)

def add(self, name: str, size: int, signed=False) -> CellBuilder:
"""Generate a StdAdd cell."""
self.prog.import_("primitives/binary_operators.futil")
Expand Down Expand Up @@ -319,6 +312,10 @@ def and_(self, name: str, size: int) -> CellBuilder:
"""Generate a StdAnd cell."""
return self.cell(name, ast.Stdlib.op("and", size, False))

def not_(self, name: str, size: int) -> CellBuilder:
"""Generate a StdNot cell."""
return self.cell(name, ast.Stdlib.op("not", size, False))

def pipelined_mult(self, name: str) -> CellBuilder:
"""Generate a pipelined multiplier."""
self.prog.import_("primitives/pipelined.futil")
Expand Down Expand Up @@ -620,13 +617,21 @@ def port(self, name: str) -> ExprBuilder:
"""Build a port access expression."""
return ExprBuilder(ast.Atom(ast.CompPort(self._cell.id, name)))

def is_mem_d1(self) -> bool:
"""Check if the cell is a StdMemD1 cell."""
def is_primitive(self, prim_name) -> bool:
"""Check if the cell is an instance of the primitive {prim_name}."""
return (
isinstance(self._cell.comp, ast.CompInst)
and self._cell.comp.id == "std_mem_d1"
and self._cell.comp.id == prim_name
)

def is_std_mem_d1(self) -> bool:
"""Check if the cell is a StdMemD1 cell."""
return self.is_primitive("std_mem_d1")

def is_seq_mem_d1(self) -> bool:
"""Check if the cell is a SeqMemD1 cell."""
return self.is_primitive("seq_mem_d1")

@classmethod
def unwrap_id(cls, obj):
if isinstance(obj, cls):
Expand Down
95 changes: 51 additions & 44 deletions calyx-py/calyx/builder_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ def insert_sub(comp: cb.ComponentBuilder, left, right, cellname, width):
return insert_comb_group(comp, left, right, sub_cell, f"{cellname}_group")


def insert_bitwise_flip_reg(comp: cb.ComponentBuilder, reg, cellname, width):
"""Inserts wiring into component {comp} to bitwise-flip the contents of {reg}.
Returns a handle to the group that does this.
"""
not_cell = comp.not_(cellname, width)
with comp.group(f"{cellname}_group") as not_group:
not_cell.in_ = reg.out
reg.write_en = 1
reg.in_ = not_cell.out
not_group.done = reg.done
return not_group


def insert_incr(comp: cb.ComponentBuilder, reg, cellname, val=1):
"""Inserts wiring into component {comp} to increment register {reg} by {val}.
1. Within component {comp}, creates a group called {cellname}_group.
Expand Down Expand Up @@ -145,13 +159,27 @@ def insert_decr(comp: cb.ComponentBuilder, reg, cellname, val=1):
return decr_group


def mem_load(comp: cb.ComponentBuilder, mem, i, reg, group):
"""Loads a value from one memory into a register.
def insert_reg_store(comp: cb.ComponentBuilder, reg, val, group):
"""Stores a value in a register.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, sets the register {reg} to {val}.
3. Returns the group that does this.
"""
with comp.group(group) as reg_grp:
reg.in_ = val
reg.write_en = 1
reg_grp.done = reg.done
return reg_grp


def mem_load_std_d1(comp: cb.ComponentBuilder, mem, i, reg, group):
"""Loads a value from one memory (std_d1) into a register.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, reads from memory {mem} at address {i}.
3. Writes the value into register {reg}.
4. Returns the group that does this.
"""
assert mem.is_std_mem_d1()
with comp.group(group) as load_grp:
mem.addr0 = i
reg.write_en = 1
Expand All @@ -160,13 +188,14 @@ def mem_load(comp: cb.ComponentBuilder, mem, i, reg, group):
return load_grp


def mem_store(comp: cb.ComponentBuilder, mem, i, val, group):
"""Stores a value from one memory into another.
def mem_store_std_d1(comp: cb.ComponentBuilder, mem, i, val, group):
"""Stores a value into a (std_d1) memory.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, reads from {val}.
3. Writes the value into memory {mem} at address i.
4. Returns the group that does this.
"""
assert mem.is_std_mem_d1()
with comp.group(group) as store_grp:
mem.addr0 = i
mem.write_en = 1
Expand All @@ -175,36 +204,33 @@ def mem_store(comp: cb.ComponentBuilder, mem, i, val, group):
return store_grp


def insert_reg_store(comp: cb.ComponentBuilder, reg, val, group):
"""Stores a value in a register.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, sets the register {reg} to {val}.
3. Returns the group that does this.
"""
with comp.group(group) as reg_grp:
reg.in_ = val
reg.write_en = 1
reg_grp.done = reg.done
return reg_grp


def mem_read_seqd1(comp: cb.ComponentBuilder, mem, i, group):
def mem_read_seq_d1(comp: cb.ComponentBuilder, mem, i, group):
"""Given a seq_mem_d1, reads from memory at address i.
Note that this does not write the value anywhere.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, reads from memory {mem} at address {i},
thereby "latching" the value.
3. Returns the group that does this.
"""
assert mem.is_seq_mem_d1
assert mem.is_seq_mem_d1()
with comp.group(group) as read_grp:
mem.addr0 = i
mem.read_en = 1
read_grp.done = mem.read_done
return read_grp


def mem_write_seqd1_to_reg(comp: cb.ComponentBuilder, mem, reg, group):
def mem_write_seq_d1_to_reg(comp: cb.ComponentBuilder, mem, reg, group):
"""Given a seq_mem_d1 that is already assumed to have a latched value,
reads the latched value and writes it to a register.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, reads from memory {mem}.
3. Writes the value into register {reg}.
4. Returns the group that does this.
"""
assert mem.is_seq_mem_d1
assert mem.is_seq_mem_d1()
with comp.group(group) as write_grp:
reg.write_en = 1
reg.in_ = mem.read_data
Expand All @@ -213,13 +239,14 @@ def mem_write_seqd1_to_reg(comp: cb.ComponentBuilder, mem, reg, group):


def mem_store_seq_d1(comp: cb.ComponentBuilder, mem, i, val, group):
"""Stores a value from one memory into another.
"""Given a seq_mem_d1, stores a value into memory at address i.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, reads from {val}.
3. Writes the value into memory {mem} at address i.
4. Returns the group that does this.
"""
assert mem.is_seq_mem_d1
assert mem.is_seq_mem_d1()
with comp.group(group) as store_grp:
mem.addr0 = i
mem.write_en = 1
Expand All @@ -228,34 +255,14 @@ def mem_store_seq_d1(comp: cb.ComponentBuilder, mem, i, val, group):
return store_grp


def reg_swap(comp: cb.ComponentBuilder, a, b, group):
"""Swaps the values of two registers.
1. Within component {comp}, creates a group called {group}.
2. Reads the value of {a} into a temporary register.
3. Writes the value of {b} into {a}.
4. Writes the value of the temporary register into {b}.
5. Returns the group that does this.
"""
with comp.group(group) as swap_grp:
tmp = comp.reg("tmp", 1)
tmp.write_en = 1
tmp.in_ = a.out
a.write_en = 1
a.in_ = b.out
b.write_en = 1
b.in_ = tmp.out
swap_grp.done = b.done
return swap_grp


def insert_mem_load_to_mem(comp: cb.ComponentBuilder, mem, i, ans, j, group):
"""Loads a value from one std_mem_d1 memory into another.
1. Within component {comp}, creates a group called {group}.
2. Within {group}, reads from memory {mem} at address {i}.
3. Writes the value into memory {ans} at address {j}.
4. Returns the group that does this.
"""
assert mem.is_mem_d1() and ans.is_mem_d1()
assert mem.is_std_mem_d1() and ans.is_std_mem_d1()
with comp.group(group) as load_grp:
mem.addr0 = i
ans.write_en = 1
Expand Down
141 changes: 141 additions & 0 deletions calyx-py/calyx/queue_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# pylint: disable=import-error
import calyx.builder as cb
import calyx.builder_util as util

MAX_CMDS = 15


def insert_raise_err_if_i_eq_max_cmds(prog):
"""Inserts a the component `raise_err_if_i_eq_MAX_CMDS` into the program.
It has:
- one input, `i`.
- one ref register, `err`.
If `i` equals MAX_CMDS, it raises the `err` flag.
"""
raise_err_if_i_eq_max_cmds: cb.ComponentBuilder = prog.component(
"raise_err_if_i_eq_MAX_CMDS"
)
i = raise_err_if_i_eq_max_cmds.input("i", 32)
err = raise_err_if_i_eq_max_cmds.reg("err", 1, is_ref=True)

i_eq_max_cmds = util.insert_eq(
raise_err_if_i_eq_max_cmds, i, MAX_CMDS, "i_eq_MAX_CMDS", 32
)
raise_err = util.insert_reg_store(raise_err_if_i_eq_max_cmds, err, 1, "raise_err")

raise_err_if_i_eq_max_cmds.control += [
cb.if_(
i_eq_max_cmds[0].out,
i_eq_max_cmds[1],
raise_err,
)
]

return raise_err_if_i_eq_max_cmds


def insert_main(prog, queue):
"""Inserts the component `main` into the program.
This will be used to `invoke` the component `queue` and feed it a list of commands.
"""
main: cb.ComponentBuilder = prog.component("main")

# The user-facing interface of the `main` component is:
# - a list of commands (the input)
# where each command is a 32-bit unsigned integer, with the following format:
# `0`: pop
# any other value: push that value
# - a list of answers (the output).
#
# The user-facing interface of the `queue` component is:
# - one input, `cmd`.
# where each command is a 32-bit unsigned integer, with the following format:
# `0`: pop
# any other value: push that value
# - one ref register, `ans`, into which the result of a pop is written.
# - one ref register, `err`, which is raised if an error occurs.

commands = main.seq_mem_d1("commands", 32, MAX_CMDS, 32, is_external=True)
ans_mem = main.seq_mem_d1("ans_mem", 32, 10, 32, is_external=True)

# The two components we'll use:
queue = main.cell("myqueue", queue)
raise_err_if_i_eq_max_cmds = main.cell(
"raise_err_if_i_eq_MAX_CMDS", insert_raise_err_if_i_eq_max_cmds(prog)
)

# We will use the `invoke` method to call the `queue` component.
# The queue component takes two inputs by reference and one input directly.
# The two `ref` inputs:
err = main.reg("err", 1) # A flag to indicate an error
ans = main.reg("ans", 32) # A memory to hold the answer of a pop

# We will set up a while loop that runs over the command list, relaying
# the commands to the `queue` component.
# It will run until the `err` flag is raised by the `queue` component.

i = main.reg("i", 32) # The index of the command we're currently processing
j = main.reg("j", 32) # The index on the answer-list we'll write to
cmd = main.reg("command", 32) # The command we're currently processing

incr_i = util.insert_incr(main, i, "incr_i") # i++
incr_j = util.insert_incr(main, j, "incr_j") # j++
err_eq_0 = util.insert_eq(main, err.out, 0, "err_eq_0", 1) # is `err` flag down?
cmd_eq_0 = util.insert_eq(main, cmd.out, 0, "cmd_eq_0", 32) # cmd == 0
cmd_neq_0 = util.insert_neq(
main, cmd.out, cb.const(32, 0), "cmd_neq_0", 32
) # cmd != 0

read_cmd = util.mem_read_seq_d1(main, commands, i.out, "read_cmd_phase1")
write_cmd_to_reg = util.mem_write_seq_d1_to_reg(
main, commands, cmd, "write_cmd_phase2"
)

write_ans = util.mem_store_seq_d1(main, ans_mem, j.out, ans.out, "write_ans")

main.control += [
cb.while_(
err_eq_0[0].out,
err_eq_0[1], # Run while the `err` flag is down
[
read_cmd, # Read `commands[i]`
write_cmd_to_reg, # Write it to `cmd`
cb.par( # Now, in parallel, act based on the value of `cmd`
cb.if_(
# Is this a pop?
cmd_eq_0[0].out,
cmd_eq_0[1],
[ # A pop
cb.invoke( # First we call pop
queue,
in_cmd=cmd.out,
ref_ans=ans,
ref_err=err,
),
# AM: my goal is that,
# if err flag comes back raised,
# we do not perform this write or this incr_j
write_ans,
incr_j,
],
),
cb.if_( # Is this a push?
cmd_neq_0[0].out,
cmd_neq_0[1],
cb.invoke( # A push
queue,
in_cmd=cmd.out,
ref_ans=ans,
ref_err=err,
),
),
),
incr_i, # Increment the command index
cb.invoke( # If i = MAX_CMDS, raise error flag
raise_err_if_i_eq_max_cmds, in_i=i.out, ref_err=err
), # AM: hella hacky
],
),
]
Loading

0 comments on commit 0bd028c

Please sign in to comment.