Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multidimensional Memory in calyx-py #2240

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,24 @@ def comb_mem_d2(
is_external,
is_ref,
)

def comb_mem_dn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider a more descriptive name. dn reads like "down" to me! Also flesh out the docstring to explain what's up.

self,
name: str,
bitwidth: int,
lens: List[int],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lens: List[int],
lengths: List[int],

idx_size: int,
is_external: bool = False,
is_ref: bool = False,
) -> CellBuilder:
"""Generate a StdMemD1 cell that abstracts to an n-dimensional memory."""
self.prog.import_("primitives/memories/comb.futil")
prod = 1
for l in lens:
prod *= l
return self.cell(
name, ast.Stdlib.comb_mem_d1(bitwidth, prod, idx_size), is_external, is_ref
)

def seq_mem_d1(
self,
Expand Down Expand Up @@ -482,6 +500,24 @@ def seq_mem_d2(
is_ref,
)

def seq_mem_dn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above

self,
name: str,
bitwidth: int,
lens : List[int],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above

idx_size: int,
is_external: bool = False,
is_ref: bool=False
) -> CellBuilder:
"""Generate a SeqMemD1 cell that abstracts to an n-dimensional memory."""
self.prog.import_("primitives/memories/seq.futil")
prod = 1
for l in lens:
prod *= l
return self.cell(
name, ast.Stdlib.seq_mem_d1(bitwidth, prod, idx_size), is_external, is_ref
)

def binary(
self,
operation: str,
Expand Down Expand Up @@ -898,6 +934,48 @@ def mem_latch_d2(self, mem, i, j, groupname):
latch_grp.done = mem.done
return latch_grp

def flatten_idx(self, dims, indices):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be private?

"""Translate an n-dimensional index into a corresponding 1d index"""
assert len(dims) == len(indices)
i = len(indices) - 1
prod = 1
total = indices[-1]
while i > 0:
prod *= dims[i]
total += prod * indices[i-1]
i -= 1
return total

def mem_load_dn(self, mem, dims, indices, reg, groupname):
"""Inserts wiring into `self` to perform `reg := mem[i1][i2]...[in]`,
where `mem` is a seq_dn memory or a comb_mem_d1 memory
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, also explain the use-case of how indices has to be used. And consider the assertion that Rachit suggests below!

"""
assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1()
is_comb = mem.is_comb_mem_d1()
with self.group(groupname) as load_grp:
mem.addr0 = self.flatten_idx(dims, indices)
if is_comb:
reg.write_en = 1
reg.in_ = mem.read_data
else:
mem.content_en = 1
reg.write_en = mem.done @ 1
reg.in_ = mem.done @ mem.read_data
load_grp.done = reg.done
return load_grp

def mem_latch_dn(self, mem, dims, indices, groupname):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if I'm understanding this correctly, this requires indices to be constants at compile-time. If so, it would be useful to add an assert to check that all the indices are in fact fact and throw an error if they are not.

"""Inserts wiring into `self` to latch `mem[i]`,
where `mem` is a seq_mem_d1 memory.
A user can later read `mem.out` and get the latched value.
"""
assert mem.is_seq_mem_d1()
with self.group(groupname) as latch_grp:
mem.addr0 = self.flatten_idx(dims, indices)
mem.content_en = HI
latch_grp.done = mem.done
return latch_grp

def mem_store_d1(self, mem, i, val, groupname):
"""Inserts wiring into `self` to perform `mem[i] := val`,
where `mem` is a seq_d1 memory or a comb_mem_d1 memory
Expand Down Expand Up @@ -928,6 +1006,21 @@ def mem_store_d2(self, mem, i, j, val, groupname):
if not is_comb:
mem.content_en = 1
return store_grp

def mem_store_dn(self, mem, dims, indices, val, groupname):
"""Inserts wiring into `self` to perform `mem[i] := val`,
where `mem` is a seq_d2 memory or a comb_mem_d2 memory
"""
assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1()
is_comb = mem.is_comb_mem_d1()
with self.group(groupname) as store_grp:
mem.addr0 = self.flatten_idx(dims, indices)
mem.write_en = 1
mem.write_data = val
store_grp.done = mem.done
if not is_comb:
mem.content_en = 1
return store_grp

def mem_load_to_mem(self, mem, i, ans, j, groupname):
"""Inserts wiring into `self` to perform `ans[j] := mem[i]`,
Expand Down
Loading