Skip to content

Commit

Permalink
Add custom combiners (kuznia-rdzeni#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Apr 23, 2024
1 parent 4a08c12 commit 1175c46
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 6 deletions.
78 changes: 78 additions & 0 deletions test/transactron/test_methods.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
import pytest
import random
from amaranth import *
Expand All @@ -13,6 +14,8 @@

from unittest import TestCase

from transactron.utils.assign import AssignArg


class TestDefMethod(TestCaseWithSimulator):
class CircuitTestModule(Elaboratable):
Expand Down Expand Up @@ -556,6 +559,81 @@ def process():
sim.add_sync_process(process)


class CustomCombinerMethodCircuit(Elaboratable):
def elaborate(self, platform):
m = TModule()

self.ready = Signal()
self.running = Signal()

def combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg:
result = C(0)
for i, v in enumerate(args):
result = result ^ Mux(runs[i], v.data, 0)
return {"data": result}

method = Method(i=data_layout(WIDTH), o=data_layout(WIDTH), nonexclusive=True, combiner=combiner)

@def_method(m, method, self.ready)
def _(data: Value):
m.d.comb += self.running.eq(1)
return {"data": data}

m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method))
m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method))

return m


class TestCustomCombinerMethod(TestCaseWithSimulator):
def test_custom_combiner_method(self):
circ = CustomCombinerMethodCircuit()

def process():
for x in range(8):
t1en = bool(x & 1)
t2en = bool(x & 2)
mrdy = bool(x & 4)

val1 = random.randrange(0, 2**WIDTH)
val2 = random.randrange(0, 2**WIDTH)
val1e = val1 if t1en else 0
val2e = val2 if t2en else 0

yield from circ.t1.call_init(data=val1)
yield from circ.t2.call_init(data=val2)

if t1en:
yield from circ.t1.enable()
else:
yield from circ.t1.disable()

if t2en:
yield from circ.t2.enable()
else:
yield from circ.t2.disable()

if mrdy:
yield circ.ready.eq(1)
else:
yield circ.ready.eq(0)

yield Settle()

assert bool((yield circ.running)) == ((t1en or t2en) and mrdy)
assert bool((yield from circ.t1.done())) == (t1en and mrdy)
assert bool((yield from circ.t2.done())) == (t2en and mrdy)

if t1en and mrdy:
assert (yield from circ.t1.get_outputs()) == {"data": val1e ^ val2e}

if t2en and mrdy:
assert (yield from circ.t2.get_outputs()) == {"data": val1e ^ val2e}

with self.run_simulation(circ) as sim:
sim.add_sync_process(process)


class DataDependentConditionalCircuit(Elaboratable):
def __init__(self, n=2, ready_function=lambda arg: arg.data != 3):
self.method = Method(i=data_layout(n))
Expand Down
9 changes: 4 additions & 5 deletions transactron/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def rec(transaction: Transaction, source: TransactionOrMethod):
@staticmethod
def _method_calls(
m: Module, method_map: MethodMap
) -> tuple[Mapping["Method", Sequence[ValueLike]], Mapping["Method", Sequence[ValueLike]]]:
args = defaultdict[Method, list[ValueLike]](list)
runs = defaultdict[Method, list[ValueLike]](list)
) -> tuple[Mapping["Method", Sequence[MethodStruct]], Mapping["Method", Sequence[Value]]]:
args = defaultdict[Method, list[MethodStruct]](list)
runs = defaultdict[Method, list[Value]](list)

for source in method_map.methods_and_transactions:
if isinstance(source, Method):
Expand Down Expand Up @@ -359,8 +359,7 @@ def elaborate(self, platform):
raise RuntimeError(f"Single-caller method '{method.name}' called more than once")

runs = Cat(method_runs[method])
for i in OneHotSwitchDynamic(m, runs):
m.d.comb += method.data_in.eq(method_args[method][i])
m.d.comb += assign(method.data_in, method.combiner(m, method_args[method], runs), fields=AssignType.ALL)

if "TRANSACTRON_VERBOSE" in environ:
self.print_info(cgr, porder, ccs, method_map)
Expand Down
18 changes: 17 additions & 1 deletion transactron/core/method.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from transactron.utils import *
from amaranth import *
from amaranth import tracer
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
i: MethodLayout = (),
o: MethodLayout = (),
nonexclusive: bool = False,
combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None,
single_caller: bool = False,
src_loc: int | SrcLoc = 0,
):
Expand All @@ -77,6 +79,12 @@ def __init__(
transactions in the same clock cycle. If such a situation happens,
the method still is executed only once, and each of the callers
receive its output. Nonexclusive methods cannot have inputs.
combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg
If `nonexclusive` is true, the combiner function combines the
arguments from multiple calls to this method into a single
argument, which is passed to the method body. The third argument
is a bit vector, whose n-th bit is 1 if the n-th call is active
in a given cycle.
single_caller: bool
If true, this method is intended to be called from a single
transaction. An error will be thrown if called from multiple
Expand All @@ -86,17 +94,25 @@ def __init__(
Alternatively, the source location to use instead of the default.
"""
super().__init__(src_loc=get_src_loc(src_loc))

def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg:
ret = Signal(from_method_layout(i))
for k in OneHotSwitchDynamic(m, runs):
m.d.comb += ret.eq(args[k])
return ret

self.owner, owner_name = get_caller_class_name(default="$method")
self.name = name or tracer.get_var_name(depth=2, default=owner_name)
self.ready = Signal(name=self.owned_name + "_ready")
self.run = Signal(name=self.owned_name + "_run")
self.data_in: MethodStruct = Signal(from_method_layout(i))
self.data_out: MethodStruct = Signal(from_method_layout(o))
self.nonexclusive = nonexclusive
self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner
self.single_caller = single_caller
self.validate_arguments: Optional[Callable[..., ValueLike]] = None
if nonexclusive:
assert len(self.data_in.as_value()) == 0
assert len(self.data_in.as_value()) == 0 or combiner is not None

@property
def layout_in(self):
Expand Down

0 comments on commit 1175c46

Please sign in to comment.