Skip to content

Commit

Permalink
NTT: using a few "modern" eDSL features (#2034)
Browse files Browse the repository at this point in the history
* Neaten mul_group

* Neaten another helper

* A little more neatening

* Less use of Cell

* Reduce manual imports

* Catch up to main

* Stray doctring issue
  • Loading branch information
anshumanmohan authored May 13, 2024
1 parent 015da80 commit 0adf64d
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 110 deletions.
34 changes: 34 additions & 0 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,18 @@ def sub(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdSub cell."""
return self.binary("sub", size, name, signed)

def div_pipe(
self, size: int, name: str = None, signed: bool = False
) -> CellBuilder:
"""Generate a Div_Pipe cell."""
return self.binary("div_pipe", size, name, signed)

def mult_pipe(
self, size: int, name: str = None, signed: bool = False
) -> CellBuilder:
"""Generate a Mult_Pipe cell."""
return self.binary("mult_pipe", size, name, signed)

def gt(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdGt cell."""
return self.binary("gt", size, name, signed)
Expand Down Expand Up @@ -466,6 +478,28 @@ def binary_use(self, left, right, cell, groupname=None):
cell.right = right
return CellAndGroup(cell, comb_group)

def binary_use_names(self, cellname, leftname, rightname, groupname=None):
"""Accepts the name of a cell that performs some computation on two values.
Accepts the names of cells that contain those two values.
Creates a group that wires up the cell with those values.
Returns the group created.
group `groupname` {
`cellname`.left = `leftname`.out;
`cellname`.right = `rightname`.out;
`groupname`.go = 1;
`groupname`.done = `cellname`.done;
}
"""
cell = self.get_cell(cellname)
groupname = groupname or f"{cellname}_group"
with self.group(groupname) as group:
cell.left = self.get_cell(leftname).out
cell.right = self.get_cell(rightname).out
cell.go = HI
group.done = cell.done
return group

def try_infer_width(self, width, left, right):
"""If `width` is None, try to infer it from `left` or `right`.
If that fails, raise an error.
Expand Down
110 changes: 28 additions & 82 deletions frontends/ntt-pipeline/gen-ntt-pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from prettytable import PrettyTable
import numpy as np
import calyx.py_ast as ast
from calyx.py_ast import CompVar, Cell, Stdlib
import calyx.builder as cb
from calyx.utils import bits_needed

Expand All @@ -28,7 +27,7 @@ def reduce_parallel_control_pass(component: ast.Component, N: int, input_size: i
...
"""
assert (
N is not None and 0 < N < input_size and (not (N & (N - 1)))
N and 0 < N < input_size and (not (N & (N - 1)))
), f"""N: {N} should be a power of two within bounds (0, {input_size})."""

reduced_controls = []
Expand Down Expand Up @@ -168,15 +167,12 @@ def fresh_comp_index(op):

def mul_group(comp: cb.ComponentBuilder, stage, mul_tuple):
mul_index, k, phi_index = mul_tuple

mul = comp.get_cell(f"mult_pipe{mul_index}")
phi = comp.get_cell(f"phi{phi_index}")
reg = comp.get_cell(f"r{k}")
with comp.group(f"s{stage}_mul{mul_index}") as g:
mul.left = phi.out
mul.right = reg.out
mul.go = 1
g.done = mul.done
comp.binary_use_names(
f"mult_pipe{mul_index}",
f"phi{phi_index}",
f"r{k}",
f"s{stage}_mul{mul_index}",
)

def op_mod_group(comp: cb.ComponentBuilder, stage, row, operations_tuple):
lhs, op, mul_index = operations_tuple
Expand All @@ -201,10 +197,7 @@ def op_mod_group(comp: cb.ComponentBuilder, stage, row, operations_tuple):
def precursor_group(comp: cb.ComponentBuilder, row):
r = comp.get_cell(f"r{row}")
A = comp.get_cell(f"A{row}")
with comp.group(f"precursor_{row}") as g:
r.in_ = A.out
r.write_en = 1
g.done = r.done
comp.reg_store(r, A.out, f"precursor_{row}")

def preamble_group(comp: cb.ComponentBuilder, row):
reg = comp.get_cell(f"r{row}")
Expand All @@ -223,70 +216,24 @@ def preamble_group(comp: cb.ComponentBuilder, row):
def epilogue_group(comp: cb.ComponentBuilder, row):
input = comp.get_cell("a")
A = comp.get_cell(f"A{row}")
with comp.group(f"epilogue_{row}") as epilogue:
input.addr0 = row
input.write_en = 1
input.write_data = A.out
epilogue.done = input.done

def cells():
input = CompVar("a")
phis = CompVar("phis")
memories = [
Cell(
input, Stdlib.comb_mem_d1(input_bitwidth, n, bitwidth), is_external=True
),
Cell(
phis, Stdlib.comb_mem_d1(input_bitwidth, n, bitwidth), is_external=True
),
]
r_regs = [
Cell(CompVar(f"r{r}"), Stdlib.register(input_bitwidth)) for r in range(n)
]
A_regs = [
Cell(CompVar(f"A{r}"), Stdlib.register(input_bitwidth)) for r in range(n)
]
mul_regs = [
Cell(CompVar(f"mul{i}"), Stdlib.register(input_bitwidth))
for i in range(n // 2)
]
phi_regs = [
Cell(CompVar(f"phi{r}"), Stdlib.register(input_bitwidth)) for r in range(n)
]
mod_pipes = [
Cell(
CompVar(f"mod_pipe{r}"),
Stdlib.op("div_pipe", input_bitwidth, signed=True),
)
for r in range(n)
]
mult_pipes = [
Cell(
CompVar(f"mult_pipe{i}"),
Stdlib.op("mult_pipe", input_bitwidth, signed=True),
)
for i in range(n // 2)
]
adds = [
Cell(CompVar(f"add{i}"), Stdlib.op("add", input_bitwidth, signed=True))
for i in range(n // 2)
]
subs = [
Cell(CompVar(f"sub{i}"), Stdlib.op("sub", input_bitwidth, signed=True))
for i in range(n // 2)
]

return (
memories
+ r_regs
+ A_regs
+ mul_regs
+ phi_regs
+ mod_pipes
+ mult_pipes
+ adds
+ subs
)
comp.mem_store_comb_mem_d1(input, row, A.out, f"epilogue_{row}")

def insert_cells(comp: cb.ComponentBuilder):
# memories
comp.comb_mem_d1("a", input_bitwidth, n, bitwidth, is_external=True)
comp.comb_mem_d1("phis", input_bitwidth, n, bitwidth, is_external=True)

for r in range(n):
comp.reg(input_bitwidth, f"r{r}") # r_regs
comp.reg(input_bitwidth, f"A{r}") # A_regs
comp.reg(input_bitwidth, f"phi{r}") # phi_regs
comp.div_pipe(input_bitwidth, f"mod_pipe{r}", signed=True) # mod_pipes

for i in range(n // 2):
comp.reg(input_bitwidth, f"mult{i}") # mul_regs
comp.mult_pipe(input_bitwidth, f"mult_pipe{i}", signed=True) # mult_pipes
comp.add(input_bitwidth, f"add{i}", signed=True) # adds
comp.sub(input_bitwidth, f"sub{i}", signed=True) # subs

def wires(main: cb.ComponentBuilder):
for r in range(n):
Expand Down Expand Up @@ -325,9 +272,8 @@ def control():

pp_table(operations, multiplies, n, num_stages)
prog = cb.Builder()
prog.import_("primitives/binary_operators.futil")
prog.import_("primitives/memories/comb.futil")
main = prog.component("main", cells())
main = prog.component("main")
insert_cells(main)
wires(main)
main.component.controls = control()
return prog.program
Expand Down
28 changes: 14 additions & 14 deletions tests/frontend/ntt-pipeline/ntt-4-reduced-2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,35 @@
// | 3 | a[1] - a[3] * phis[1] | a[2] - a[3] * phis[3] |
// +---+-----------------------+-----------------------+
import "primitives/core.futil";
import "primitives/binary_operators.futil";
import "primitives/memories/comb.futil";
import "primitives/binary_operators.futil";
component main() -> () {
cells {
@external a = comb_mem_d1(32, 4, 3);
@external phis = comb_mem_d1(32, 4, 3);
r0 = std_reg(32);
r1 = std_reg(32);
r2 = std_reg(32);
r3 = std_reg(32);
A0 = std_reg(32);
A1 = std_reg(32);
A2 = std_reg(32);
A3 = std_reg(32);
mul0 = std_reg(32);
mul1 = std_reg(32);
phi0 = std_reg(32);
phi1 = std_reg(32);
phi2 = std_reg(32);
phi3 = std_reg(32);
mod_pipe0 = std_sdiv_pipe(32);
r1 = std_reg(32);
A1 = std_reg(32);
phi1 = std_reg(32);
mod_pipe1 = std_sdiv_pipe(32);
r2 = std_reg(32);
A2 = std_reg(32);
phi2 = std_reg(32);
mod_pipe2 = std_sdiv_pipe(32);
r3 = std_reg(32);
A3 = std_reg(32);
phi3 = std_reg(32);
mod_pipe3 = std_sdiv_pipe(32);
mult0 = std_reg(32);
mult_pipe0 = std_smult_pipe(32);
mult_pipe1 = std_smult_pipe(32);
add0 = std_sadd(32);
add1 = std_sadd(32);
sub0 = std_ssub(32);
mult1 = std_reg(32);
mult_pipe1 = std_smult_pipe(32);
add1 = std_sadd(32);
sub1 = std_ssub(32);
}
wires {
Expand Down
28 changes: 14 additions & 14 deletions tests/frontend/ntt-pipeline/ntt-4.expect
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,35 @@
// | 3 | a[1] - a[3] * phis[1] | a[2] - a[3] * phis[3] |
// +---+-----------------------+-----------------------+
import "primitives/core.futil";
import "primitives/binary_operators.futil";
import "primitives/memories/comb.futil";
import "primitives/binary_operators.futil";
component main() -> () {
cells {
@external a = comb_mem_d1(32, 4, 3);
@external phis = comb_mem_d1(32, 4, 3);
r0 = std_reg(32);
r1 = std_reg(32);
r2 = std_reg(32);
r3 = std_reg(32);
A0 = std_reg(32);
A1 = std_reg(32);
A2 = std_reg(32);
A3 = std_reg(32);
mul0 = std_reg(32);
mul1 = std_reg(32);
phi0 = std_reg(32);
phi1 = std_reg(32);
phi2 = std_reg(32);
phi3 = std_reg(32);
mod_pipe0 = std_sdiv_pipe(32);
r1 = std_reg(32);
A1 = std_reg(32);
phi1 = std_reg(32);
mod_pipe1 = std_sdiv_pipe(32);
r2 = std_reg(32);
A2 = std_reg(32);
phi2 = std_reg(32);
mod_pipe2 = std_sdiv_pipe(32);
r3 = std_reg(32);
A3 = std_reg(32);
phi3 = std_reg(32);
mod_pipe3 = std_sdiv_pipe(32);
mult0 = std_reg(32);
mult_pipe0 = std_smult_pipe(32);
mult_pipe1 = std_smult_pipe(32);
add0 = std_sadd(32);
add1 = std_sadd(32);
sub0 = std_ssub(32);
mult1 = std_reg(32);
mult_pipe1 = std_smult_pipe(32);
add1 = std_sadd(32);
sub1 = std_ssub(32);
}
wires {
Expand Down

0 comments on commit 0adf64d

Please sign in to comment.