Skip to content

Commit

Permalink
No more guard conjunction to declare group done signals (#2131)
Browse files Browse the repository at this point in the history
* One less guard disjunct in gen_exp

* No guard conj in sdn.py

* No more & in ntt

* No more guard conj in tuple

* No more group conj in gen_exp

* Stray expect file
  • Loading branch information
anshumanmohan authored Jun 13, 2024
1 parent e18da1b commit 87f8462
Show file tree
Hide file tree
Showing 21 changed files with 220 additions and 94 deletions.
27 changes: 18 additions & 9 deletions calyx-py/calyx/gen_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,17 @@ def generate_fp_pow_component(
)

# groups
with comp.group("init") as init:
with comp.group("init_pow") as init_pow:
pow.in_ = FixedPoint(
"1.0", width, int_width, is_signed=is_signed
).unsigned_integer()
pow.write_en = 1
init_pow.done = pow.done

with comp.group("init_count") as init_count:
count.in_ = 0
count.write_en = 1
init.done = (pow.done & count.done) @ 1
init_count.done = count.done

with comp.group("execute_mul") as execute_mul:
mul.left = comp.this().base
Expand All @@ -73,7 +76,10 @@ def generate_fp_pow_component(
with comp.continuous:
comp.this().out = pow.out

comp.control += [init, while_with(cond, par(execute_mul, incr_count))]
comp.control += [
par(init_pow, init_count),
while_with(cond, par(execute_mul, incr_count)),
]

return comp.component

Expand Down Expand Up @@ -312,18 +318,20 @@ def generate_groups(
int_x = comp.get_cell("int_x")
frac_x = comp.get_cell("frac_x")
one = comp.get_cell("one")
with comp.group("split_bits") as split_bits:
with comp.group("split_bits_int_x") as split_bits_int_x:
and0.left = input.out
and0.right = const(width, 2**width - 2**frac_width)
rsh.left = and0.out
rsh.right = const(width, frac_width)
int_x.write_en = 1
int_x.in_ = rsh.out
split_bits_int_x.done = int_x.done
with comp.group("split_bits_frac_x") as split_bits_frac_x:
and1.left = input.out
and1.right = const(width, (2**frac_width) - 1)
int_x.write_en = 1
frac_x.write_en = 1
int_x.in_ = rsh.out
frac_x.in_ = and1.out
split_bits.done = (int_x.done & frac_x.done) @ 1
split_bits_frac_x.done = frac_x.done

if is_signed:
mult_pipe = comp.get_cell("mult_pipe1")
Expand Down Expand Up @@ -411,7 +419,8 @@ def generate_control(comp: ComponentBuilder, degree: int, is_signed: bool):
if is_signed:
lt = comp.get_cell("lt")
init = comp.get_group("init")
split_bits = comp.get_group("split_bits")
split_bits_int_x = comp.get_group("split_bits_int_x")
split_bits_frac_x = comp.get_group("split_bits_frac_x")

# TODO (griffin): This is a hack to avoid inserting empty seqs. Maybe worth
# moving into the add method of ControlBuilder?
Expand All @@ -430,7 +439,7 @@ def generate_control(comp: ComponentBuilder, degree: int, is_signed: bool):
if is_signed
else []
),
split_bits,
par(split_bits_int_x, split_bits_frac_x),
pow_invokes,
consume_pow,
mult_by_reciprocal,
Expand Down
9 changes: 6 additions & 3 deletions calyx-py/test/correctness/queues/sdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,18 @@ def insert_controller(prog, name, stats_component):
count_0 = controller.reg(32)
count_1 = controller.reg(32)

with controller.group("get_data_locally") as get_data_locally:
with controller.group("get_data_locally_count0") as get_data_locally_count0:
count_0.in_ = stats.count_0
count_0.write_en = 1
get_data_locally_count0.done = count_0.done

with controller.group("get_data_locally_count1") as get_data_locally_count1:
count_1.in_ = stats.count_1
count_1.write_en = 1
get_data_locally.done = (count_0.done & count_1.done) @ 1
get_data_locally_count1.done = count_1.done

# The main logic.
controller.control += get_data_locally
controller.control += cb.par(get_data_locally_count0, get_data_locally_count1)
# Great, now I have the data around locally.

return controller
Expand Down
10 changes: 7 additions & 3 deletions calyx-py/test/correctness/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,23 @@ def insert_main(prog):
mem1.content_en = cb.HI
run_tuplify.done = mem1.done

with comp.group("run_untuplify") as run_untuplify:
with comp.group("run_untuplify_fst") as run_untuplify_fst:
untuplify.tup = cb.const(64, 17179869186)
mem2.addr0 = cb.const(1, 0)
mem2.write_en = cb.HI
mem2.write_data = untuplify.fst
mem2.content_en = cb.HI
run_untuplify_fst.done = mem2.done

with comp.group("run_untuplify_snd") as run_untuplify_snd:
untuplify.tup = cb.const(64, 17179869186)
mem3.addr0 = cb.const(1, 0)
mem3.write_en = cb.HI
mem3.write_data = untuplify.snd
mem3.content_en = cb.HI
run_untuplify.done = (mem2.done & mem3.done) @ cb.HI
run_untuplify_snd.done = mem3.done

comp.control += cb.par(run_tuplify, run_untuplify)
comp.control += cb.par(run_tuplify, run_untuplify_fst, run_untuplify_snd)

return comp

Expand Down
22 changes: 18 additions & 4 deletions frontends/ntt-pipeline/gen-ntt-pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,16 @@ def preamble_group(comp: cb.ComponentBuilder, row):
phi = comp.get_cell(f"phi{row}")
input = comp.get_cell("a")
phis = comp.get_cell("phis")
with main.group(f"preamble_{row}") as preamble:
with main.group(f"preamble_{row}_reg") as preamble_reg:
input.addr0 = row
phis.addr0 = row
reg.write_en = 1
reg.in_ = input.read_data
preamble_reg.done = reg.done
with main.group(f"preamble_{row}_phi") as preamble_phi:
phis.addr0 = row
phi.write_en = 1
phi.in_ = phis.read_data
preamble.done = (reg.done & phi.done) @ 1
preamble_phi.done = phi.done

def epilogue_group(comp: cb.ComponentBuilder, row):
input = comp.get_cell("a")
Expand Down Expand Up @@ -250,7 +252,19 @@ def wires(main: cb.ComponentBuilder):
epilogue_group(main, r)

def control():
preambles = [ast.SeqComp([ast.Enable(f"preamble_{r}") for r in range(n)])]
preambles = [
ast.SeqComp(
[
ast.ParComp(
[
ast.Enable(f"preamble_{r}_reg"),
ast.Enable(f"preamble_{r}_phi"),
]
)
for r in range(n)
]
)
]
epilogues = [ast.SeqComp([ast.Enable(f"epilogue_{r}") for r in range(n)])]

ntt_stages = []
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/any-base-1.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 213,
"cycles": 212,
"memories": {
"b": [
"4.5"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/any-base-2.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 208,
"cycles": 207,
"memories": {
"b": [
"7.5"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/any-base-3.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 304,
"cycles": 303,
"memories": {
"b": [
"0.75"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/degree-4-signed.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 56,
"cycles": 55,
"memories": {
"ret": [
"2.7182769775390625"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/degree-4-unsigned.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 49,
"cycles": 48,
"memories": {
"ret": [
"2.7182769775390625"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/degree-8-signed.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 134,
"cycles": 133,
"memories": {
"ret": [
"0.0001068115234375"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/degree-8-unsigned.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 75,
"cycles": 74,
"memories": {
"ret": [
"9181.710357666015625"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/exp/neg-base.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 221,
"cycles": 220,
"memories": {
"b": [
"-2.600006103515625"
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/ntt-pipeline/ntt-16-reduced-4.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 663,
"cycles": 647,
"memories": {
"a": [
7371,
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/ntt-pipeline/ntt-16.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 216,
"cycles": 200,
"memories": {
"a": [
7371,
Expand Down
2 changes: 1 addition & 1 deletion tests/correctness/ntt-pipeline/ntt-8.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 150,
"cycles": 142,
"memories": {
"a": [
5390,
Expand Down
28 changes: 20 additions & 8 deletions tests/frontend/exp/degree-2-unsigned.expect
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,21 @@ component exp(x: 32) -> (out: 32) {
exponent_value.in = x;
init[done] = exponent_value.done;
}
group split_bits {
group split_bits_int_x {
and0.left = exponent_value.out;
and0.right = 32'd4294901760;
rsh.left = and0.out;
rsh.right = 32'd16;
int_x.write_en = 1'd1;
int_x.in = rsh.out;
split_bits_int_x[done] = int_x.done;
}
group split_bits_frac_x {
and1.left = exponent_value.out;
and1.right = 32'd65535;
int_x.write_en = 1'd1;
frac_x.write_en = 1'd1;
int_x.in = rsh.out;
frac_x.in = and1.out;
split_bits[done] = (int_x.done & frac_x.done) ? 1'd1;
split_bits_frac_x[done] = frac_x.done;
}
group consume_pow2<"promotable"=1> {
p2.write_en = 1'd1;
Expand Down Expand Up @@ -82,7 +85,10 @@ component exp(x: 32) -> (out: 32) {
control {
seq {
init;
split_bits;
par {
split_bits_int_x;
split_bits_frac_x;
}
par {
invoke pow1(base=e.out, integer_exp=int_x.out)();
invoke pow2(base=frac_x.out, integer_exp=c2.out)();
Expand Down Expand Up @@ -110,12 +116,15 @@ component fp_pow(base: 32, integer_exp: 32) -> (out: 32) {
lt_1 = std_lt(32);
}
wires {
group init {
group init_pow {
pow.in = 32'd65536;
pow.write_en = 1'd1;
init_pow[done] = pow.done;
}
group init_count {
count.in = 32'd0;
count.write_en = 1'd1;
init[done] = (pow.done & count.done) ? 1'd1;
init_count[done] = count.done;
}
group execute_mul {
mul.left = base;
Expand All @@ -140,7 +149,10 @@ component fp_pow(base: 32, integer_exp: 32) -> (out: 32) {
}
control {
seq {
init;
par {
init_pow;
init_count;
}
while lt_1.out with lt_1_group {
par {
execute_mul;
Expand Down
28 changes: 20 additions & 8 deletions tests/frontend/exp/degree-4-signed.expect
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,21 @@ component exp(x: 16) -> (out: 16) {
exponent_value.in = x;
init[done] = exponent_value.done;
}
group split_bits {
group split_bits_int_x {
and0.left = exponent_value.out;
and0.right = 16'd65280;
rsh.left = and0.out;
rsh.right = 16'd8;
int_x.write_en = 1'd1;
int_x.in = rsh.out;
split_bits_int_x[done] = int_x.done;
}
group split_bits_frac_x {
and1.left = exponent_value.out;
and1.right = 16'd255;
int_x.write_en = 1'd1;
frac_x.write_en = 1'd1;
int_x.in = rsh.out;
frac_x.in = and1.out;
split_bits[done] = (int_x.done & frac_x.done) ? 1'd1;
split_bits_frac_x[done] = frac_x.done;
}
group negate {
mult_pipe1.left = exponent_value.out;
Expand Down Expand Up @@ -162,7 +165,10 @@ component exp(x: 16) -> (out: 16) {
if lt.out with is_negative {
negate;
}
split_bits;
par {
split_bits_int_x;
split_bits_frac_x;
}
par {
invoke pow1(base=e.out, integer_exp=int_x.out)();
invoke pow2(base=frac_x.out, integer_exp=c2.out)();
Expand Down Expand Up @@ -203,12 +209,15 @@ component fp_pow(base: 16, integer_exp: 16) -> (out: 16) {
lt_1 = std_slt(16);
}
wires {
group init {
group init_pow {
pow.in = 16'd256;
pow.write_en = 1'd1;
init_pow[done] = pow.done;
}
group init_count {
count.in = 16'd0;
count.write_en = 1'd1;
init[done] = (pow.done & count.done) ? 1'd1;
init_count[done] = count.done;
}
group execute_mul {
mul.left = base;
Expand All @@ -233,7 +242,10 @@ component fp_pow(base: 16, integer_exp: 16) -> (out: 16) {
}
control {
seq {
init;
par {
init_pow;
init_count;
}
while lt_1.out with lt_1_group {
par {
execute_mul;
Expand Down
Loading

0 comments on commit 87f8462

Please sign in to comment.