Skip to content

Commit e4e6fe0

Browse files
rloufbrandonwillard
authored andcommitted
Add a method to specify ExpressionTuple evaluation function
1 parent b5acdc8 commit e4e6fe0

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

etuples/core.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from collections import deque
55
from collections.abc import Generator, Sequence
6-
from typing import Any, Callable
6+
from typing import Callable
77

88
etuple_repr = reprlib.Repr()
99
etuple_repr.maxstring = 100
@@ -178,8 +178,16 @@ def eval_obj(self):
178178
)
179179
return trampoline_eval(self._eval_step())
180180

181-
def _eval_apply(self, op: Callable, op_args: inspect.BoundArguments) -> Any:
182-
return op(*op_args.args, **op_args.kwargs)
181+
def _eval_apply_fn(self, op: Callable) -> Callable:
182+
"""Return the callable used to evaluate the expression tuple.
183+
184+
The expression tuple's operator can be any `Callable`, i.e. either
185+
a function or an instance of a class that defines `__call__`. In
186+
the latter case, one can evalute the expression tuple using a
187+
method other than `__call__` by overloading this method.
188+
189+
"""
190+
return op
183191

184192
def _eval_step(self):
185193
if len(self._tuple) == 0:
@@ -210,7 +218,7 @@ def _eval_step(self):
210218
evaled_args.append(i)
211219

212220
try:
213-
op_sig = inspect.signature(op)
221+
op_sig = inspect.signature(self._eval_apply_fn(op))
214222
except ValueError:
215223
# This handles some builtin function types
216224
_evaled_obj = op(*(evaled_args + [kw.value for kw in evaled_kwargs]))
@@ -220,7 +228,7 @@ def _eval_step(self):
220228
)
221229
op_args.apply_defaults()
222230

223-
_evaled_obj = self._eval_apply(op, op_args)
231+
_evaled_obj = self._eval_apply_fn(op)(*op_args.args, **op_args.kwargs)
224232

225233
if isinstance(_evaled_obj, Generator):
226234
self._evaled_obj = _evaled_obj

tests/test_core.py

+16
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ def test_ExpressionTuple(capsys):
6868
ExpressionTuple((print, "hi")).eval_obj
6969

7070

71+
def test_eval_apply_fn():
72+
class Add(object):
73+
def __call__(self):
74+
return None
75+
76+
def add(self, x, y):
77+
return x + y
78+
79+
class AddExpressionTuple(ExpressionTuple):
80+
def _eval_apply_fn(self, op):
81+
return op.add
82+
83+
op = Add()
84+
assert AddExpressionTuple((op, 1, 2)).evaled_obj == 3
85+
86+
7187
def test_etuple():
7288
"""Test basic `etuple` functionality."""
7389

0 commit comments

Comments
 (0)