Skip to content

Commit

Permalink
pickle compatibility ensured
Browse files Browse the repository at this point in the history
  • Loading branch information
dg-pb committed Jun 23, 2024
1 parent 7957a97 commit 8aaee6a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 50 deletions.
26 changes: 6 additions & 20 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,31 +392,17 @@ def __call__(self, /, *args, **keywords):
__repr__ = recursive_repr()(_partial_repr)

def __reduce__(self):
state = (
self.func,
self.args,
self.keywords or None,
self.placeholder_count,
self.__dict__ or None
)
return type(self), (self.func,), state
return type(self), (self.func,), (self.func, self.args,
self.keywords or None, self.__dict__ or None)

def __setstate__(self, state):
if not isinstance(state, tuple):
raise TypeError("argument to __setstate__ must be a tuple")
state_len = len(state)
if state_len == 4:
# Support pre-placeholder de-serialization
func, args, kwds, namespace = state
phcount = 0
elif state_len == 5:
func, args, kwds, phcount, namespace = state
else:
raise TypeError(f"expected 4 or 5 items in state, got {state_len}")

if len(state) != 4:
raise TypeError(f"expected 4 items in state, got {len(state)}")
func, args, kwds, namespace = state
if (not callable(func) or not isinstance(args, tuple) or
(kwds is not None and not isinstance(kwds, dict)) or
not isinstance(phcount, int) or
(namespace is not None and not isinstance(namespace, dict))):
raise TypeError("invalid partial state")

Expand All @@ -432,7 +418,7 @@ def __setstate__(self, state):
self.func = func
self.args = args
self.keywords = kwds
self.placeholder_count = phcount
self.placeholder_count = args.count(Placeholder)

try:
from _functools import partial, Placeholder
Expand Down
6 changes: 1 addition & 5 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,9 @@ def test_setstate(self):
self.assertEqual(f(), ((), {}))

# Set State with placeholders
f = self.partial(signature)
f.__setstate__((capture, (1,), dict(a=10), 0, dict(attr=[])))
self.assertEqual(signature(f), (capture, (1,), dict(a=10), dict(attr=[])))

PH = self.module.Placeholder
f = self.partial(signature)
f.__setstate__((capture, (PH, 1), dict(a=10), 1, dict(attr=[])))
f.__setstate__((capture, (PH, 1), dict(a=10), dict(attr=[])))
self.assertEqual(signature(f), (capture, (PH, 1), dict(a=10), dict(attr=[])))
with self.assertRaises(TypeError):
self.assertEqual(f(), (PH, 1), dict(a=10))
Expand Down
51 changes: 26 additions & 25 deletions Modules/_functoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
pto->fn = Py_NewRef(func);

pto->placeholder = state->placeholder;
if (Py_Is(PyTuple_GET_ITEM(args, new_nargs), pto->placeholder)) {
if (new_nargs && Py_Is(PyTuple_GET_ITEM(args, new_nargs), pto->placeholder)) {
PyErr_SetString(PyExc_TypeError,
"trailing Placeholders are not allowed");
return NULL;
Expand All @@ -204,15 +204,11 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
return NULL;
}

Py_ssize_t phcount = 0;
PyObject *item;
/* Count placeholders */
if (new_nargs > 1) {
for (Py_ssize_t i = 0; i < new_nargs - 1; i++) {
item = PyTuple_GET_ITEM(new_args, i);
if (Py_Is(item, pto->placeholder)) {
phcount++;
}
Py_ssize_t phcount = 0;
for (Py_ssize_t i = 0; i < new_nargs - 1; i++) {
if (Py_Is(PyTuple_GET_ITEM(new_args, i), pto->placeholder)) {
phcount++;
}
}
/* merge args with args of `func` which is `partial` */
Expand All @@ -222,6 +218,7 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
if (new_nargs > pto_phcount) {
tot_nargs += new_nargs - pto_phcount;
}
PyObject *item;
PyObject *tot_args = PyTuple_New(tot_nargs);
for (Py_ssize_t i = 0, j = 0; i < tot_nargs; i++) {
if (i < npargs) {
Expand Down Expand Up @@ -635,37 +632,27 @@ partial_repr(partialobject *pto)
static PyObject *
partial_reduce(partialobject *pto, PyObject *unused)
{
return Py_BuildValue("O(O)(OOOnO)", Py_TYPE(pto), pto->fn, pto->fn,
pto->args, pto->kw, pto->phcount,
return Py_BuildValue("O(O)(OOOO)", Py_TYPE(pto), pto->fn, pto->fn,
pto->args, pto->kw,
pto->dict ? pto->dict : Py_None);
}

static PyObject *
partial_setstate(partialobject *pto, PyObject *state)
{
PyObject *fn, *fnargs, *kw, *dict;
Py_ssize_t phcount;

if (!PyTuple_Check(state)) {
PyErr_SetString(PyExc_TypeError, "invalid partial state");
return NULL;
}
Py_ssize_t state_len = PyTuple_GET_SIZE(state);
int parse_rtrn;
if (state_len == 4) {
/* pre-placeholder support */
parse_rtrn = PyArg_ParseTuple(state, "OOOO", &fn, &fnargs, &kw, &dict);
phcount = 0;
}
else if (state_len == 5) {
parse_rtrn = PyArg_ParseTuple(state, "OOOnO", &fn, &fnargs,
&kw, &phcount, &dict);
}
else {
if (state_len != 4) {
PyErr_Format(PyExc_TypeError,
"expected 4 or 5 items in state, got %zd", state_len);
"expected 4 items in state, got %zd", state_len);
return NULL;
}
if (!parse_rtrn ||
if (!PyArg_ParseTuple(state, "OOOO", &fn, &fnargs, &kw, &dict) ||
!PyCallable_Check(fn) ||
!PyTuple_Check(fnargs) ||
(kw != Py_None && !PyDict_Check(kw)))
Expand All @@ -674,6 +661,20 @@ partial_setstate(partialobject *pto, PyObject *state)
return NULL;
}

Py_ssize_t nargs = PyTuple_GET_SIZE(fnargs);
if (nargs && Py_Is(PyTuple_GET_ITEM(fnargs, nargs - 1), pto->placeholder)) {
PyErr_SetString(PyExc_TypeError,
"trailing Placeholders are not allowed");
return NULL;
}
/* Count placeholders */
Py_ssize_t phcount = 0;
for (Py_ssize_t i = 0; i < nargs - 1; i++) {
if (Py_Is(PyTuple_GET_ITEM(fnargs, i), pto->placeholder)) {
phcount++;
}
}

if(!PyTuple_CheckExact(fnargs))
fnargs = PySequence_Tuple(fnargs);
else
Expand Down

0 comments on commit 8aaee6a

Please sign in to comment.