diff --git a/test/transactron/test_methods.py b/test/transactron/test_methods.py index 03d9a8707..6cdcc350f 100644 --- a/test/transactron/test_methods.py +++ b/test/transactron/test_methods.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import pytest import random from amaranth import * @@ -13,6 +14,8 @@ from unittest import TestCase +from transactron.utils.assign import AssignArg + class TestDefMethod(TestCaseWithSimulator): class CircuitTestModule(Elaboratable): @@ -556,6 +559,81 @@ def process(): sim.add_sync_process(process) +class CustomCombinerMethodCircuit(Elaboratable): + def elaborate(self, platform): + m = TModule() + + self.ready = Signal() + self.running = Signal() + + def combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: + result = C(0) + for i, v in enumerate(args): + result = result ^ Mux(runs[i], v.data, 0) + return {"data": result} + + method = Method(i=data_layout(WIDTH), o=data_layout(WIDTH), nonexclusive=True, combiner=combiner) + + @def_method(m, method, self.ready) + def _(data: Value): + m.d.comb += self.running.eq(1) + return {"data": data} + + m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method)) + m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method)) + + return m + + +class TestCustomCombinerMethod(TestCaseWithSimulator): + def test_custom_combiner_method(self): + circ = CustomCombinerMethodCircuit() + + def process(): + for x in range(8): + t1en = bool(x & 1) + t2en = bool(x & 2) + mrdy = bool(x & 4) + + val1 = random.randrange(0, 2**WIDTH) + val2 = random.randrange(0, 2**WIDTH) + val1e = val1 if t1en else 0 + val2e = val2 if t2en else 0 + + yield from circ.t1.call_init(data=val1) + yield from circ.t2.call_init(data=val2) + + if t1en: + yield from circ.t1.enable() + else: + yield from circ.t1.disable() + + if t2en: + yield from circ.t2.enable() + else: + yield from circ.t2.disable() + + if mrdy: + yield circ.ready.eq(1) + else: + yield circ.ready.eq(0) + + yield Settle() + + assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) + assert bool((yield from circ.t1.done())) == (t1en and mrdy) + assert bool((yield from circ.t2.done())) == (t2en and mrdy) + + if t1en and mrdy: + assert (yield from circ.t1.get_outputs()) == {"data": val1e ^ val2e} + + if t2en and mrdy: + assert (yield from circ.t2.get_outputs()) == {"data": val1e ^ val2e} + + with self.run_simulation(circ) as sim: + sim.add_sync_process(process) + + class DataDependentConditionalCircuit(Elaboratable): def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): self.method = Method(i=data_layout(n)) diff --git a/transactron/core/manager.py b/transactron/core/manager.py index 237d1a058..02dc8c381 100644 --- a/transactron/core/manager.py +++ b/transactron/core/manager.py @@ -202,9 +202,9 @@ def rec(transaction: Transaction, source: TransactionOrMethod): @staticmethod def _method_calls( m: Module, method_map: MethodMap - ) -> tuple[Mapping["Method", Sequence[ValueLike]], Mapping["Method", Sequence[ValueLike]]]: - args = defaultdict[Method, list[ValueLike]](list) - runs = defaultdict[Method, list[ValueLike]](list) + ) -> tuple[Mapping["Method", Sequence[MethodStruct]], Mapping["Method", Sequence[Value]]]: + args = defaultdict[Method, list[MethodStruct]](list) + runs = defaultdict[Method, list[Value]](list) for source in method_map.methods_and_transactions: if isinstance(source, Method): @@ -359,8 +359,7 @@ def elaborate(self, platform): raise RuntimeError(f"Single-caller method '{method.name}' called more than once") runs = Cat(method_runs[method]) - for i in OneHotSwitchDynamic(m, runs): - m.d.comb += method.data_in.eq(method_args[method][i]) + m.d.comb += assign(method.data_in, method.combiner(m, method_args[method], runs), fields=AssignType.ALL) if "TRANSACTRON_VERBOSE" in environ: self.print_info(cgr, porder, ccs, method_map) diff --git a/transactron/core/method.py b/transactron/core/method.py index 98fb59f3d..85860f601 100644 --- a/transactron/core/method.py +++ b/transactron/core/method.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from transactron.utils import * from amaranth import * from amaranth import tracer @@ -59,6 +60,7 @@ def __init__( i: MethodLayout = (), o: MethodLayout = (), nonexclusive: bool = False, + combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None, single_caller: bool = False, src_loc: int | SrcLoc = 0, ): @@ -77,6 +79,12 @@ def __init__( transactions in the same clock cycle. If such a situation happens, the method still is executed only once, and each of the callers receive its output. Nonexclusive methods cannot have inputs. + combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg + If `nonexclusive` is true, the combiner function combines the + arguments from multiple calls to this method into a single + argument, which is passed to the method body. The third argument + is a bit vector, whose n-th bit is 1 if the n-th call is active + in a given cycle. single_caller: bool If true, this method is intended to be called from a single transaction. An error will be thrown if called from multiple @@ -86,6 +94,13 @@ def __init__( Alternatively, the source location to use instead of the default. """ super().__init__(src_loc=get_src_loc(src_loc)) + + def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: + ret = Signal(from_method_layout(i)) + for k in OneHotSwitchDynamic(m, runs): + m.d.comb += ret.eq(args[k]) + return ret + self.owner, owner_name = get_caller_class_name(default="$method") self.name = name or tracer.get_var_name(depth=2, default=owner_name) self.ready = Signal(name=self.owned_name + "_ready") @@ -93,10 +108,11 @@ def __init__( self.data_in: MethodStruct = Signal(from_method_layout(i)) self.data_out: MethodStruct = Signal(from_method_layout(o)) self.nonexclusive = nonexclusive + self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner self.single_caller = single_caller self.validate_arguments: Optional[Callable[..., ValueLike]] = None if nonexclusive: - assert len(self.data_in.as_value()) == 0 + assert len(self.data_in.as_value()) == 0 or combiner is not None @property def layout_in(self):