Skip to content

Commit

Permalink
Static Systolic Array (#1740)
Browse files Browse the repository at this point in the history
* added static option for systolic array

* flake8

* tests don't duplicate data

* reformat test
  • Loading branch information
calebmkim committed Oct 16, 2023
1 parent 57d1977 commit d4b19b2
Show file tree
Hide file tree
Showing 27 changed files with 738 additions and 1,457 deletions.
90 changes: 55 additions & 35 deletions frontends/systolic-lang/gen_array_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,32 +246,35 @@ def get_pe_invoke(r, c, mul_ready):
)


def init_iter_limit(comp: cb.ComponentBuilder, depth_port, partial_iter_limit):
def init_iter_limit(
comp: cb.ComponentBuilder, depth_port, config: SystolicConfiguration
):
"""
Builds group that instantiates the dynamic/runtime values for the systolic
array: its depth and iteration limit/count (since its iteration limit depends on
its depth).
iteration limit = depth + partial_iter_limit
"""
iter_limit = comp.reg("iter_limit", BITWIDTH)
iter_limit_add = comp.add(BITWIDTH, "iter_limit_add")
with comp.static_group("init_iter_limit", 1):
iter_limit_add.left = partial_iter_limit
iter_limit_add.right = depth_port
iter_limit.in_ = iter_limit_add.out
iter_limit.write_en = 1
# Only need to initalize this group if
if not config.static:
partial_iter_limit = config.top_length + config.left_length + 4
iter_limit = comp.reg("iter_limit", BITWIDTH)
iter_limit_add = comp.add(BITWIDTH, "iter_limit_add")
with comp.static_group("init_iter_limit", 1):
iter_limit_add.left = partial_iter_limit
iter_limit_add.right = depth_port
iter_limit.in_ = iter_limit_add.out
iter_limit.write_en = 1


def instantiate_idx_groups(comp: cb.ComponentBuilder):
def instantiate_idx_groups(comp: cb.ComponentBuilder, config: SystolicConfiguration):
"""
Builds groups that instantiate idx to 0 and increment idx.
Also builds groups that set cond_reg to 1 (runs before the while loop)
and that sets cond_reg to (idx + 1 < iter_limit).
"""
idx = comp.reg("idx", BITWIDTH)
add = comp.add(BITWIDTH, "idx_add")
iter_limit = comp.get_cell("iter_limit")
lt_iter_limit = comp.lt(BITWIDTH, "lt_iter_limit")

with comp.static_group("init_idx", 1):
idx.in_ = 0
Expand All @@ -281,9 +284,12 @@ def instantiate_idx_groups(comp: cb.ComponentBuilder):
add.right = 1
idx.in_ = add.out
idx.write_en = 1
with comp.continuous:
lt_iter_limit.left = idx.out
lt_iter_limit.right = iter_limit.out
if not config.static:
iter_limit = comp.get_cell("iter_limit")
lt_iter_limit = comp.lt(BITWIDTH, "lt_iter_limit")
with comp.continuous:
lt_iter_limit.left = idx.out
lt_iter_limit.right = iter_limit.out


def instantiate_calyx_adds(comp, nec_ranges) -> list:
Expand Down Expand Up @@ -396,8 +402,22 @@ def gen_schedules(
`pe_write_sched` contains when to "write" the PE value into the output ports
(e.g., this.r0_valid)
"""

def depth_plus_const(const: int):
"""
Returns depth + const. If config.static, then this is an int.
Otherwise, we need to perform a Calyx addition to figure this out.
"""
if config.static:
# return an int
return config.get_contraction_dimension() + const
else:
# return a CalyxAdd object, whose value is determined after generation
depth_port = comp.this().depth
return CalyxAdd(depth_port, const)

left_length, top_length = config.left_length, config.top_length
depth_port = comp.this().depth

schedules = {}
update_sched = np.zeros((left_length, top_length), dtype=object)
pe_fill_sched = np.zeros((left_length, top_length), dtype=object)
Expand All @@ -407,13 +427,13 @@ def gen_schedules(
for row in range(0, left_length):
for col in range(0, top_length):
pos = row + col
update_sched[row][col] = (pos, CalyxAdd(depth_port, pos))
update_sched[row][col] = (pos, depth_plus_const(pos))
pe_fill_sched[row][col] = (pos + 1, pos + 5)
pe_accum_sched[row][col] = (pos + 5, CalyxAdd(depth_port, pos + 5))
pe_move_sched[row][col] = (pos + 1, CalyxAdd(depth_port, pos + 1))
pe_accum_sched[row][col] = (pos + 5, depth_plus_const(pos + 5))
pe_move_sched[row][col] = (pos + 1, depth_plus_const(pos + 1))
pe_write_sched[row][col] = (
CalyxAdd(depth_port, pos + 5),
CalyxAdd(depth_port, pos + 6),
depth_plus_const(pos + 5),
depth_plus_const(pos + 6),
)
schedules["update_sched"] = update_sched
schedules["fill_sched"] = pe_fill_sched
Expand Down Expand Up @@ -458,15 +478,12 @@ def generate_control(
control = []
top_length, left_length = config.top_length, config.left_length

# Initialize all memories.
control.append(
py_ast.StaticParComp(
[
py_ast.Enable("init_idx"),
py_ast.Enable("init_iter_limit"),
]
)
)
# Initialize the idx and iteration_limit.
# We only need to initialize iteration_limit for dynamic configurations
init_groups = [py_ast.Enable("init_idx")]
if not config.static:
init_groups += [py_ast.Enable("init_iter_limit")]
control.append(py_ast.StaticParComp(init_groups))

# source_pos metadata init
init_tag = 0
Expand Down Expand Up @@ -532,11 +549,16 @@ def counter():
while_body = py_ast.StaticParComp(while_body_stmts)

# build the while loop with condition cond_reg.
cond_reg_port = comp.get_cell("lt_iter_limit").port("out")
while_loop = cb.while_(cond_reg_port, while_body)
if config.static:
while_loop = cb.static_repeat(config.get_iteration_count(), while_body)
else:
cond_reg_port = comp.get_cell("lt_iter_limit").port("out")
while_loop = cb.while_(cond_reg_port, while_body)

control.append(while_loop)

if config.static:
return py_ast.StaticSeqComp(stmts=control), source_map
return py_ast.SeqComp(stmts=control), source_map


Expand All @@ -551,9 +573,7 @@ def create_systolic_array(prog: cb.Builder, config: SystolicConfiguration):
computational_unit = prog.component(SYSTOLIC_ARRAY_COMP)
depth_port = computational_unit.input("depth", BITWIDTH)
# initialize the iteration limit to top_length + left_length + depth + 4
init_iter_limit(
computational_unit, depth_port, config.top_length + config.left_length + 4
)
init_iter_limit(computational_unit, depth_port, config)

schedules = gen_schedules(config, computational_unit)
nec_ranges = set()
Expand All @@ -562,7 +582,7 @@ def create_systolic_array(prog: cb.Builder, config: SystolicConfiguration):
instantiate_calyx_adds(computational_unit, nec_ranges)

# instantiate groups that handles the idx variables
instantiate_idx_groups(computational_unit)
instantiate_idx_groups(computational_unit, config)
list1, list2 = zip(*nec_ranges)
nec_ranges_beg = set(list1)
nec_ranges_end = set(list2)
Expand Down
28 changes: 28 additions & 0 deletions frontends/systolic-lang/systolic_arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def parse_arguments(self):
parser.add_argument("-ll", "--left-length", type=int)
parser.add_argument("-ld", "--left-depth", type=int)
parser.add_argument("-p", "--post-op", type=str, default=None)
parser.add_argument("-s", "--static", action="store_true")

args = parser.parse_args()

Expand All @@ -37,6 +38,7 @@ def parse_arguments(self):
self.left_length = args.left_length
self.left_depth = args.left_depth
self.post_op = args.post_op
self.static = args.static
elif args.file is not None:
with open(args.file, "r") as f:
spec = json.load(f)
Expand All @@ -46,6 +48,8 @@ def parse_arguments(self):
self.left_depth = spec["left_depth"]
# default to not perform leaky_relu
self.post_op = spec.get("post_op", None)
# default to non-static (i.e., dynamic contraction dimension)
self.static = spec.get("static", False)
else:
parser.error(
"Need to pass either `FILE` or all of `"
Expand All @@ -63,3 +67,27 @@ def get_output_dimensions(self):
of num_rows x num_cols)
"""
return (self.left_length, self.top_length)

def get_contraction_dimension(self):
"""
Returns the contraction dimension
"""
assert (
self.left_depth == self.top_depth
), "left_depth and top_depth should be same"
# Could have also returend self.top_depth
return self.left_depth

def get_iteration_count(self):
"""
Returns the iteration count if self.static
Otherwise throws an error
"""
# Could have also returend self.top_depth
if self.static:
(num_out_rows, num_out_cols) = self.get_output_dimensions()
return self.get_contraction_dimension() + num_out_rows + num_out_cols + 4
raise Exception(
"Cannot get iteration count for systolic array with dynamic \
contraction dimension"
)
73 changes: 69 additions & 4 deletions runt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,77 @@ fud e --from systolic --to jq \
"""

[[tests]]
name = "[frontend] systolic array output correctness"
name = "[frontend] systolic array mmult static correctness"
paths = [
"tests/correctness/systolic/mmult-inputs/*.data"
]
cmd = """
fud e --from systolic --to jq \
--through verilog \
--through dat \
-s verilog.data {} \
-s calyx.exec './target/debug/calyx' \
-s calyx.flags "-d well-formed" \
-s jq.expr ".memories" \
{}_static.systolic -q
"""
expect_dir="tests/correctness/systolic/mmult-expect"

[[tests]]
name = "[frontend] systolic array mmult dynamic correctness"
paths = [
"tests/correctness/systolic/mmult-inputs/*.data"
]
cmd = """
fud e --from systolic --to jq \
--through verilog \
--through dat \
-s verilog.data {} \
-s calyx.exec './target/debug/calyx' \
-s calyx.flags "-d well-formed" \
-s jq.expr ".memories" \
{}_dynamic.systolic -q
"""
expect_dir="tests/correctness/systolic/mmult-expect"

[[tests]]
name = "[frontend] systolic array relu static correctness"
paths = [
"tests/correctness/systolic/relu-inputs/*.data"
]
cmd = """
fud e --from systolic --to jq \
--through verilog \
--through dat \
-s verilog.data {} \
-s calyx.exec './target/debug/calyx' \
-s calyx.flags "-d well-formed" \
-s jq.expr ".memories" \
{}_static.systolic -q
"""
expect_dir="tests/correctness/systolic/relu-expect"

[[tests]]
name = "[frontend] systolic array relu dynamic correctness"
paths = [
"tests/correctness/systolic/relu-inputs/*.data"
]
cmd = """
fud e --from systolic --to jq \
--through verilog \
--through dat \
-s verilog.data {} \
-s calyx.exec './target/debug/calyx' \
-s calyx.flags "-d well-formed" \
-s jq.expr ".memories" \
{}_dynamic.systolic -q
"""
expect_dir="tests/correctness/systolic/relu-expect"

[[tests]]
name = "[frontend] systolic array leaky relu correctness"
paths = [
"tests/correctness/systolic/output/*.systolic",
"tests/correctness/systolic/leaky-relu/*.systolic",
"tests/correctness/systolic/relu/*.systolic",
"tests/correctness/systolic/relu-dynamic/*.systolic",
]
cmd = """
fud e --from systolic --to dat \
Expand Down
44 changes: 44 additions & 0 deletions tests/correctness/systolic/mmult-expect/array-2-3-4.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"l0": [
"-1.5258636474609375",
"1.37969970703125",
"3.964019775390625"
],
"l1": [
"-4.90814208984375",
"1.1556854248046875",
"0.088134765625"
],
"out_mem_0": [
"17.829559326171875",
"-14.9852752685546875",
"-0.7954864501953125",
"-15.9398193359375"
],
"out_mem_1": [
"13.2156982421875",
"-5.021575927734375",
"20.450347900390625",
"-28.945556640625"
],
"t0": [
"-3.7752532958984375",
"-4.9618072509765625",
"4.771636962890625"
],
"t1": [
"0.8634185791015625",
"-0.4265594482421875",
"-3.29949951171875"
],
"t2": [
"-3.8273162841796875",
"1.6114501953125",
"-2.2347869873046875"
],
"t3": [
"4.7815093994140625",
"-4.69775390625",
"-0.545501708984375"
]
}
Loading

0 comments on commit d4b19b2

Please sign in to comment.