Skip to content

Commit 6d0a79c

Browse files
committed
add support for python providers creating arrays
1 parent 915a250 commit 6d0a79c

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

plugins/python/src/modulewrap.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ namespace {
170170

171171
PyGILRAII gil;
172172

173-
PyObject* result = PyObject_CallFunctionObjArgs(m_callable, (PyObject*)args..., nullptr);
173+
PyObject* result =
174+
PyObject_CallFunctionObjArgs(m_callable, lifeline_transform(args)..., nullptr);
174175

175176
std::string error_msg;
176177
if (!result) {
@@ -604,7 +605,18 @@ namespace {
604605
\
605606
Py_DECREF((PyObject*)pyobj); \
606607
return vec; \
607-
}
608+
} \
609+
\
610+
struct provider_cb_##name : public py_callback<1> { \
611+
std::shared_ptr<std::vector<cpptype>> operator()(data_cell_index const& id) \
612+
{ \
613+
PyGILRAII gil; \
614+
PyObject* arg0 = wrap_dci(id); \
615+
intptr_t pyres = call((intptr_t)arg0); /* decrefs arg0 */ \
616+
auto cres = py_to_##name(pyres); /* decrefs pyres */ \
617+
return cres; \
618+
} \
619+
};
608620

609621
NUMPY_ARRAY_CONVERTER(vint, std::int32_t, NPY_INT32, PyLong_AsLong)
610622
NUMPY_ARRAY_CONVERTER(vuint, std::uint32_t, NPY_UINT32, pylong_or_int_as_ulong)
@@ -1209,27 +1221,33 @@ static PyObject* sc_provide(py_phlex_source* src, PyObject* args, PyObject* kwds
12091221
} else if (out_type == "double") {
12101222
auto* pyc = new provider_cb_double{callable};
12111223
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
1212-
} /* else if (out_type.compare(0, 7, "ndarray") == 0 || out_type.compare(0, 4, "list") == 0) {
1224+
} else if (out_type.compare(0, 7, "ndarray") == 0 || out_type.compare(0, 4, "list") == 0) {
12131225
// TODO: just like for input types, these are hard-coded, but should be handled by
12141226
// an IDL instead.
12151227
std::string_view dtype{out_type.begin() + out_type.rfind('['), out_type.end()};
12161228
if (dtype == "[int32_t]") {
1217-
insert_converter(mod, cname, py_to_vint, out_pq, output);
1229+
auto* pyc = new provider_cb_vint{callable};
1230+
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
12181231
} else if (dtype == "[uint32_t]") {
1219-
insert_converter(mod, cname, py_to_vuint, out_pq, output);
1232+
auto* pyc = new provider_cb_vuint{callable};
1233+
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
12201234
} else if (dtype == "[int64_t]") {
1221-
insert_converter(mod, cname, py_to_vlong, out_pq, output);
1235+
auto* pyc = new provider_cb_vlong{callable};
1236+
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
12221237
} else if (dtype == "[uint64_t]") {
1223-
insert_converter(mod, cname, py_to_vulong, out_pq, output);
1238+
auto* pyc = new provider_cb_vulong{callable};
1239+
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
12241240
} else if (dtype == "[float]") {
1225-
insert_converter(mod, cname, py_to_vfloat, out_pq, output);
1241+
auto* pyc = new provider_cb_vfloat{callable};
1242+
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
12261243
} else if (dtype == "[double]") {
1227-
insert_converter(mod, cname, py_to_vdouble, out_pq, output);
1244+
auto* pyc = new provider_cb_vdouble{callable};
1245+
src->ph_source->provide(functor_name, *pyc).output_product(opq.value());
12281246
} else {
12291247
PyErr_Format(PyExc_TypeError, "unsupported collection output type \"%s\"", out_type.c_str());
12301248
return nullptr;
12311249
}
1232-
} */
1250+
}
12331251
else {
12341252
PyErr_Format(PyExc_TypeError, "unsupported output type \"%s\"", out_type.c_str());
12351253
return nullptr;

test/python/pyprovide.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,23 @@
66
"""
77

88
import numpy as np
9+
import numpy.typing as npt
910

1011
from phlex import Variant
1112

12-
specs = ((False, bool, "ib"),
13-
(-42, int, "ii32"),
14-
(42, np.uint32, "iui32"),
15-
(-27, np.int64, "ii64"),
16-
(27, np.uint64, "iui64"),
17-
(42., np.float32, "if32"),
18-
(-42., np.float64, "if64"),
19-
)
13+
_specs = ((-42, np.int32, "ii32"),
14+
(42, np.uint32, "iui32"),
15+
(-27, np.int64, "ii64"),
16+
(27, np.uint64, "iui64"),
17+
(42., np.float32, "if32"),
18+
(-42., np.float64, "if64"),
19+
)
20+
21+
specs = [(False, np.bool_, "ib")]
22+
for x, t, sf in _specs:
23+
specs.append((x, t, sf)) # type: ignore
24+
specs.append((np.array([x], dtype=t), npt.NDArray[t], "v"+sf)) # type: ignore
25+
2026

2127
def PHLEX_REGISTER_PROVIDERS(s, config):
2228
"""Register python providers for all supported types.
@@ -62,6 +68,6 @@ def a(y):
6268
return a
6369

6470
for x, t, sf in specs:
65-
f = Variant(new_a(x), {"y": t, "return": None}, "py"+t.__name__)
71+
f = Variant(new_a(x), {"y": t, "return": None}, sf+t.__name__)
6672
m.observe(f, input_family=[{"creator": "input_"+sf, "layer": "event"}])
6773

0 commit comments

Comments
 (0)