Skip to content

Commit

Permalink
Systolic: modern eDSL features (#2036)
Browse files Browse the repository at this point in the history
* Tidy one file

* Some easy cleanup in gen-systolic

* Massage away one CompInst

* Massage away all use of py_ast

* A litte cleanup

* No manual imports

* Return type annotation

* Revert change to unassociated file

* Another return type
  • Loading branch information
anshumanmohan authored May 9, 2024
1 parent 7b9ca1c commit cf8224e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 53 deletions.
55 changes: 26 additions & 29 deletions frontends/systolic-lang/gen-systolic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#!/usr/bin/env python3
import calyx.builder as cb
from systolic_arg_parser import SystolicConfiguration
from calyx import py_ast
from calyx.utils import bits_needed
from gen_array_component import (
create_systolic_array,
BITWIDTH,
SYSTOLIC_ARRAY_COMP,
NAME_SCHEME,
)
from gen_post_op import (
Expand All @@ -15,19 +13,15 @@
leaky_relu_post_op,
relu_dynamic_post_op,
OUT_MEM,
DEFAULT_POST_OP,
RELU_POST_OP,
LEAKY_RELU_POST_OP,
RELU_DYNAMIC_POST_OP,
)

# Dict that maps command line arguments (e.g., "leaky-relu") to component names
# and function that creates them.
POST_OP_DICT = {
None: (DEFAULT_POST_OP, default_post_op),
"leaky-relu": (LEAKY_RELU_POST_OP, leaky_relu_post_op),
"relu": (RELU_POST_OP, relu_post_op),
"relu-dynamic": (RELU_DYNAMIC_POST_OP, relu_dynamic_post_op),
None: default_post_op,
"leaky-relu": leaky_relu_post_op,
"relu": relu_post_op,
"relu-dynamic": relu_dynamic_post_op,
}


Expand Down Expand Up @@ -60,7 +54,7 @@ def create_mem_connections(
)


def build_main(prog, config: SystolicConfiguration, post_op_component_name):
def build_main(prog, config: SystolicConfiguration, comp_unit, postop_comp):
"""
Build the main component.
It basically connects the ports of the systolic component and post_op component
Expand All @@ -73,12 +67,8 @@ def build_main(prog, config: SystolicConfiguration, post_op_component_name):
config.left_depth,
)
main = prog.component("main")
systolic_array = main.cell(
"systolic_array_component", py_ast.CompInst(SYSTOLIC_ARRAY_COMP, [])
)
post_op = main.cell(
"post_op_component", py_ast.CompInst(post_op_component_name, [])
)
systolic_array = main.cell("systolic_array_component", comp_unit)
post_op = main.cell("post_op_component", postop_comp)
# Connections contains the RTL-like connections between the ports of
# systolic_array_comp and the post_op.
# Also connects the input memories to the systolic_array_comp and
Expand All @@ -87,7 +77,11 @@ def build_main(prog, config: SystolicConfiguration, post_op_component_name):
# Connect input memories to systolic_array
for r in range(top_length):
connections += create_mem_connections(
main, systolic_array, f"t{r}", top_depth, read_mem=True
main,
systolic_array,
f"t{r}",
top_depth,
read_mem=True,
)
for c in range(left_length):
connections += create_mem_connections(
Expand All @@ -103,7 +97,11 @@ def build_main(prog, config: SystolicConfiguration, post_op_component_name):
for i in range(left_length):
# Connect output memory to post op. want to write to this memory.
connections += create_mem_connections(
main, post_op, OUT_MEM + f"_{i}", top_length, read_mem=False
main,
post_op,
OUT_MEM + f"_{i}",
top_length,
read_mem=False,
)
# Connect systolic array to post op
connections += cb.build_connections(
Expand Down Expand Up @@ -131,28 +129,26 @@ def build_main(prog, config: SystolicConfiguration, post_op_component_name):
systolic_done_reg.write_en = systolic_array.done @ 1
systolic_done_reg.in_ = systolic_array.done @ 1
systolic_done_wire.in_ = (systolic_array.done | systolic_done_reg.out) @ 1
systolic_array.go = ~systolic_done_wire.out @ py_ast.ConstantPort(1, 1)
systolic_array.depth = py_ast.ConstantPort(BITWIDTH, left_depth)
systolic_array.go = ~systolic_done_wire.out @ cb.HI
systolic_array.depth = cb.const(BITWIDTH, left_depth)

# Triggering post_op component.
post_op.go = py_ast.ConstantPort(1, 1)
post_op.go = cb.HI
# Group is done when post_op is done.
g.done = post_op.computation_done

main.control = py_ast.Enable("perform_computation")
main.control += g


if __name__ == "__main__":
systolic_config = SystolicConfiguration()
systolic_config.parse_arguments()
# Building the main component
prog = cb.Builder()
create_systolic_array(prog, systolic_config)
comp_unit_inserted = create_systolic_array(prog, systolic_config)
if systolic_config.post_op in POST_OP_DICT.keys():
post_op_component_name, component_building_func = POST_OP_DICT[
systolic_config.post_op
]
component_building_func(prog, config=systolic_config)
component_building_func = POST_OP_DICT[systolic_config.post_op]
postop_comp_inserted = component_building_func(prog, config=systolic_config)
else:
raise ValueError(
f"{systolic_config.post_op} not supported as a post op. \
Expand All @@ -163,6 +159,7 @@ def build_main(prog, config: SystolicConfiguration, post_op_component_name):
build_main(
prog,
config=systolic_config,
post_op_component_name=post_op_component_name,
comp_unit=comp_unit_inserted,
postop_comp=postop_comp_inserted,
)
prog.program.emit()
6 changes: 5 additions & 1 deletion frontends/systolic-lang/gen_array_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def counter():
return py_ast.SeqComp(stmts=control), source_map


def create_systolic_array(prog: cb.Builder, config: SystolicConfiguration):
def create_systolic_array(
prog: cb.Builder, config: SystolicConfiguration
) -> cb.ComponentBuilder:
"""
top_length: Number of PEs in each row.
top_depth: Number of elements processed by each PE in a row.
Expand Down Expand Up @@ -430,3 +432,5 @@ def create_systolic_array(prog: cb.Builder, config: SystolicConfiguration):
control, source_map = generate_control(computational_unit, config, schedule)
computational_unit.control = control
prog.program.meta = source_map

return computational_unit
24 changes: 10 additions & 14 deletions frontends/systolic-lang/gen_pe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import calyx.builder as cb
from calyx import py_ast

# Global constant for the current bitwidth.
BITWIDTH = 32
Expand All @@ -19,28 +18,25 @@ def pe(prog: cb.Builder):
yet.
"""
comp = prog.component(name=PE_NAME, latency=1)
comp.input("top", BITWIDTH)
comp.input("left", BITWIDTH)
comp.input("mul_ready", 1)
top = comp.input("top", BITWIDTH)
left = comp.input("left", BITWIDTH)
mul_ready = comp.input("mul_ready", 1)
comp.output("out", BITWIDTH)
acc = comp.reg(BITWIDTH, "acc")
add = comp.fp_sop("adder", "add", BITWIDTH, INTWIDTH, FRACWIDTH)
mul = comp.pipelined_fp_smult("mul", BITWIDTH, INTWIDTH, FRACWIDTH)

this = comp.this()
with comp.static_group("do_add", 1):
with comp.static_group("do_add", 1) as do_add:
add.left = acc.out
add.right = mul.out
acc.in_ = add.out
acc.write_en = this.mul_ready
acc.write_en = mul_ready

with comp.static_group("do_mul", 1):
mul.left = this.top
mul.right = this.left

par = py_ast.StaticParComp([py_ast.Enable("do_add"), py_ast.Enable("do_mul")])
with comp.static_group("do_mul", 1) as do_mul:
mul.left = top
mul.right = left

with comp.continuous:
this.out = acc.out
comp.this().out = acc.out

comp.control += par
comp.control += cb.static_par(do_add, do_mul)
22 changes: 13 additions & 9 deletions frontends/systolic-lang/gen_post_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def imm_write_mem_groups(comp: cb.ComponentBuilder, row_num: int, perform_relu:

def imm_write_mem_post_op(
prog: cb.Builder, config: SystolicConfiguration, perform_relu: bool
):
) -> cb.ComponentBuilder:
"""
This post-op does nothing except immediately write to memory.
If perform_relu is true, then writes 0 to memory if result < 0, and writes
Expand All @@ -137,20 +137,22 @@ def imm_write_mem_post_op(
+ [py_ast.Enable(f"write_r{r}") for r in range(num_rows)]
)

return comp


def default_post_op(prog: cb.Builder, config: SystolicConfiguration):
"""
Default post op that immediately writes to output memory.
"""
imm_write_mem_post_op(prog=prog, config=config, perform_relu=False)
return imm_write_mem_post_op(prog=prog, config=config, perform_relu=False)


def relu_post_op(prog: cb.Builder, config: SystolicConfiguration):
"""
Relu post op that (combinationally) performs relu before
immediately writing the result to memory.
"""
imm_write_mem_post_op(prog=prog, config=config, perform_relu=True)
return imm_write_mem_post_op(prog=prog, config=config, perform_relu=True)


def add_dynamic_op_params(comp: cb.ComponentBuilder, idx_width: int):
Expand Down Expand Up @@ -206,7 +208,7 @@ def leaky_relu_comp(prog: cb.Builder, idx_width: int) -> cb.ComponentBuilder:
this.out_mem_write_data = lt.out @ fp_mult.out
g.done = this.out_mem_done

comp.control = py_ast.Enable("do_relu")
comp.control += g

return comp

Expand Down Expand Up @@ -242,7 +244,7 @@ def relu_dynamic_comp(prog: cb.Builder, idx_width: int):
# It takes one cycle to write to g
g.done = this.out_mem_done

comp.control = py_ast.Enable("do_relu")
comp.control += g

return comp

Expand Down Expand Up @@ -316,7 +318,7 @@ def build_assignment(wire, register, output_val):
comp.continuous.asgn(
wire.port("in"),
output_val.out,
register.port("out") == cb.ExprBuilder(py_ast.ConstantPort(BITWIDTH, col)),
register.port("out") == cb.const(BITWIDTH, col),
)

# Current value we are performing relu on.
Expand Down Expand Up @@ -367,7 +369,7 @@ def build_assignment(wire, register, output_val):

op_instance.go = (
row_ready_wire.out & (~row_finished_wire.out) & (~op_instance.done)
) @ cb.ExprBuilder(py_ast.ConstantPort(1, 1))
) @ cb.HI
# input ports for relu_instance
op_instance.value = cur_val.out
op_instance.idx = idx_reg.out
Expand Down Expand Up @@ -401,11 +403,13 @@ def dynamic_post_op(

comp.control = py_ast.StaticParComp(all_groups)

return comp


def leaky_relu_post_op(prog: cb.Builder, config: SystolicConfiguration):
_, num_cols = config.get_output_dimensions()
leaky_relu_op_comp = leaky_relu_comp(prog, idx_width=bits_needed(num_cols))
dynamic_post_op(
return dynamic_post_op(
prog=prog,
config=config,
post_op_component_name=LEAKY_RELU_POST_OP,
Expand All @@ -416,7 +420,7 @@ def leaky_relu_post_op(prog: cb.Builder, config: SystolicConfiguration):
def relu_dynamic_post_op(prog: cb.Builder, config: SystolicConfiguration):
_, num_cols = config.get_output_dimensions()
relu_dynamic_op_comp = relu_dynamic_comp(prog, idx_width=bits_needed(num_cols))
dynamic_post_op(
return dynamic_post_op(
prog=prog,
config=config,
post_op_component_name=RELU_DYNAMIC_POST_OP,
Expand Down

0 comments on commit cf8224e

Please sign in to comment.