Changes from 60 commits
85 commits
edc21c7
chore(pre-commit): update pre-commit hooks
XuehaiPan Nov 29, 2025
d5a1cc4
test: reorgainize test files
XuehaiPan Oct 8, 2025
8bd5fde
test: add subinterpreters tests
XuehaiPan Oct 8, 2025
23abd08
chore: reorder dependencies
XuehaiPan Nov 29, 2025
65a8fce
test: add import tests
XuehaiPan Nov 29, 2025
261729a
chore(workflows): update find options
XuehaiPan Nov 29, 2025
6b9b38a
feat: set `py::multiple_interpreters::shared_gil()`
XuehaiPan Nov 30, 2025
e4f2019
feat: set `py::multiple_interpreters::per_interpreter_gil()`
XuehaiPan Nov 30, 2025
eec4602
chore: disable subinterpreters for pre-3.14
XuehaiPan Dec 1, 2025
fade81a
chore(pre-commit): [pre-commit.ci] autoupdate
pre-commit-ci[bot] Dec 1, 2025
e7a2a89
refactor: simplify per-interpreter registry
XuehaiPan Dec 1, 2025
1fd07c8
fix: fix collections type for subinterpreters
XuehaiPan Dec 1, 2025
209801d
style: add `[[nodiscard]]` attributes
XuehaiPan Dec 2, 2025
02648ab
chore(pre-commit): update pre-commit hooks
XuehaiPan Dec 6, 2025
a32c672
chore: sync changes
XuehaiPan Dec 6, 2025
c026062
chore: sync changes
XuehaiPan Dec 6, 2025
3ba5f29
chore: sync changes
XuehaiPan Dec 6, 2025
278d920
Merge branch 'pre-commit-ci-update-config' into subinterpreters
XuehaiPan Dec 6, 2025
51f6878
Merge branch 'main' into subinterpreters
XuehaiPan Dec 8, 2025
e276ec3
chore: set macro `OPTREE_HAS_SUBINTERPRETER_SUPPORT`
XuehaiPan Dec 8, 2025
86aa6f6
fix: fix find command
XuehaiPan Dec 8, 2025
f955df0
refactor: change pointers to references
XuehaiPan Dec 8, 2025
e2463e2
test: add import tests
XuehaiPan Dec 8, 2025
be1fb8e
fix: do not use `py::gil_safe_call_once_and_store` for subinterpreters
XuehaiPan Dec 8, 2025
8d5c8bf
Merge branch 'main' into subinterpreters
XuehaiPan Dec 8, 2025
3495e72
chore: reorgainize code
XuehaiPan Dec 8, 2025
2278070
feat: add function `get_main_interpreter_id()`
XuehaiPan Dec 8, 2025
8c36df2
chore: reorgainize code
XuehaiPan Dec 10, 2025
b7d9173
fix: fix docs dependency resolving
XuehaiPan Dec 10, 2025
d25f6dc
fix(workflows/lint): fix docs dependency resolving
XuehaiPan Dec 10, 2025
af30820
Merge branch 'main' into subinterpreters
XuehaiPan Dec 10, 2025
0801449
fix(workflows/lint): fix docs dependency resolving
XuehaiPan Dec 10, 2025
8ac7a96
Merge branch 'main' into subinterpreters
XuehaiPan Dec 10, 2025
e119a2f
chore: update nightly pybind11 url
XuehaiPan Dec 14, 2025
5dcc2f8
feat: improve sanity check error messages
XuehaiPan Dec 14, 2025
1c43acc
revert
XuehaiPan Dec 14, 2025
90afec2
update
XuehaiPan Dec 14, 2025
ce661a2
chore: split ci jobs
XuehaiPan Dec 15, 2025
afae64c
fix: fix repr for exception
XuehaiPan Dec 15, 2025
5c3ba09
Merge branch 'main' into subinterpreters
XuehaiPan Dec 17, 2025
5fb4769
test: skip failed tests
XuehaiPan Dec 17, 2025
e486981
test: set no-cov for subinterpreter tests
XuehaiPan Dec 17, 2025
d89e120
test: set env for subprocess
XuehaiPan Dec 18, 2025
5efad7e
chore: split tests
XuehaiPan Dec 18, 2025
bb0a4d9
chore(pre-commit): update pre-commit hooks
XuehaiPan Dec 19, 2025
b7cb1cf
test: enable subinterpreter tests
XuehaiPan Dec 19, 2025
3fb200e
Merge remote-tracking branch 'upstream/main' into subinterpreters
XuehaiPan Dec 21, 2025
f7abc85
chore: update nightly remote
XuehaiPan Dec 25, 2025
ea8cc2c
chore: add more build time meta
XuehaiPan Dec 25, 2025
2935452
chore: update nightly remote
XuehaiPan Dec 26, 2025
fb08dec
chore: update test
XuehaiPan Dec 27, 2025
0d18d19
chore: update test
XuehaiPan Dec 27, 2025
3996bb0
chore: update macros
XuehaiPan Dec 27, 2025
c420862
chore: remove `Py_Get_ID`
XuehaiPan Dec 27, 2025
56d991f
chore: cleanup dict order namespaces
XuehaiPan Dec 28, 2025
2e5e53f
fix: fix concurrency issue
XuehaiPan Dec 28, 2025
e39789d
chore: add `[[likely]]` attribute
XuehaiPan Dec 28, 2025
389e17c
fix: fix PyPy
XuehaiPan Dec 28, 2025
302db6a
chore: use simple GIL
XuehaiPan Dec 28, 2025
7084688
test: update test
XuehaiPan Dec 28, 2025
1deda32
test: update test timeout
XuehaiPan Dec 29, 2025
37f22d9
chore: update macros
XuehaiPan Dec 29, 2025
a708f32
refactor: move dict order registry to `PyTreeTypeRegistry`
XuehaiPan Dec 31, 2025
99cc0b1
chore: handle refcount
XuehaiPan Dec 31, 2025
60b5afc
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Dec 31, 2025
0907ca9
Merge remote-tracking branch 'upstream/main' into subinterpreters
XuehaiPan Jan 3, 2026
76ffa62
chore: trigger CI
XuehaiPan Jan 7, 2026
58bebe8
chore: trigger CI
XuehaiPan Jan 8, 2026
4e9ee5c
chore: trigger CI
XuehaiPan Jan 11, 2026
0ed74e3
chore(pre-commit): update pre-commit hooks
XuehaiPan Jan 11, 2026
a68d686
chore: trigger CI
XuehaiPan Jan 15, 2026
cb83826
chore: trigger CI
XuehaiPan Jan 16, 2026
c1ca979
chore: set macro `OPTREE_HAS_READ_WRITE_LOCK`
XuehaiPan Jan 16, 2026
70529ef
chore(pre-commit): update pre-commit hooks
XuehaiPan Jan 16, 2026
7f9abfe
chore: update typos
XuehaiPan Jan 21, 2026
15804c7
Revert "chore: update typos"
XuehaiPan Jan 21, 2026
09cfe3a
Reapply "chore: update typos"
XuehaiPan Jan 21, 2026
5408166
chore: update remote URL
XuehaiPan Jan 22, 2026
c36e6bc
test: enable Python 3.14 tests for integrations
XuehaiPan Jan 22, 2026
35ec40c
chore: trigger CI
XuehaiPan Jan 22, 2026
229e14e
feat: show coredump in lint
XuehaiPan Jan 22, 2026
0f41d8d
chore: trigger CI
XuehaiPan Jan 22, 2026
9d671a2
chore(pre-commit): update pre-commit hooks
XuehaiPan Jan 23, 2026
28867f0
chore: trigger CI
XuehaiPan Jan 24, 2026
98fec43
chore: set remote packages
XuehaiPan Jan 24, 2026
8 changes: 7 additions & 1 deletion .github/workflows/tests-with-pydebug.yml
@@ -347,7 +347,13 @@ jobs:
"--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
"--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
)
make test PYTESTOPTS="${PYTESTOPTS[*]}"

if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then
make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov"
make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'"
else
make test PYTESTOPTS="${PYTESTOPTS[*]}"
fi

CORE_DUMP_FILES="$(
find . -type d -path "./venv" -prune \
94 changes: 57 additions & 37 deletions include/optree/pymacros.h
@@ -17,6 +17,8 @@ limitations under the License.

#pragma once

#include <stdexcept> // std::runtime_error

#include <Python.h>

#include <pybind11/pybind11.h>
@@ -32,6 +34,15 @@ limitations under the License.
// NOLINTNEXTLINE[bugprone-macro-parentheses]
#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0))

#if !defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \
(PYBIND11_VERSION_HEX >= 0x030002A0 /* pybind11 3.0.2.a0 */) && \
(defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT))
# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1
#else
# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT
#endif

namespace py = pybind11;

#if !defined(Py_ALWAYS_INLINE)
@@ -60,40 +71,49 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept {
}
#define Py_IsConstant(x) Py_IsConstant(x)

#define Py_Declare_ID(name) \
inline namespace { \
[[nodiscard]] inline PyObject *Py_ID_##name() { \
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<PyObject *> storage; \
return storage \
.call_once_and_store_result([]() -> PyObject * { \
PyObject * const ptr = PyUnicode_InternFromString(#name); \
if (ptr == nullptr) [[unlikely]] { \
throw py::error_already_set(); \
} \
Py_INCREF(ptr); /* leak a reference on purpose */ \
return ptr; \
}) \
.get_stored(); \
} \
} // namespace

#define Py_Get_ID(name) (::Py_ID_##name())

Py_Declare_ID(optree);
Py_Declare_ID(__main__); // __main__
Py_Declare_ID(__module__); // type.__module__
Py_Declare_ID(__qualname__); // type.__qualname__
Py_Declare_ID(__name__); // type.__name__
Py_Declare_ID(sort); // list.sort
Py_Declare_ID(copy); // dict.copy
Py_Declare_ID(OrderedDict); // OrderedDict
Py_Declare_ID(defaultdict); // defaultdict
Py_Declare_ID(deque); // deque
Py_Declare_ID(default_factory); // defaultdict.default_factory
Py_Declare_ID(maxlen); // deque.maxlen
Py_Declare_ID(_fields); // namedtuple._fields
Py_Declare_ID(_make); // namedtuple._make
Py_Declare_ID(_asdict); // namedtuple._asdict
Py_Declare_ID(n_fields); // structseq.n_fields
Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields
Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields
using interpid_t = decltype(PyInterpreterState_GetID(nullptr));

#if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)

[[nodiscard]] inline bool IsCurrentPyInterpreterMain() {
return PyInterpreterState_Get() == PyInterpreterState_Main();
}

[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() {
PyInterpreterState *interp = PyInterpreterState_Get();
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
if (interp == nullptr) [[unlikely]] {
throw std::runtime_error("Failed to get the current Python interpreter state.");
}
const interpid_t interpid = PyInterpreterState_GetID(interp);
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
return interpid;
}

[[nodiscard]] inline interpid_t GetMainPyInterpreterID() {
PyInterpreterState *interp = PyInterpreterState_Main();
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
if (interp == nullptr) [[unlikely]] {
throw std::runtime_error("Failed to get the main Python interpreter state.");
}
const interpid_t interpid = PyInterpreterState_GetID(interp);
if (PyErr_Occurred() != nullptr) [[unlikely]] {
throw py::error_already_set();
}
return interpid;
}

#else

[[nodiscard]] inline bool IsCurrentPyInterpreterMain() noexcept { return true; }
[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() noexcept { return 0; }
[[nodiscard]] inline interpid_t GetMainPyInterpreterID() noexcept { return 0; }

#endif
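The `OPTREE_HAS_SUBINTERPRETER_SUPPORT` guard earlier in this file relies on `NONZERO_OR_EMPTY` accepting a macro that is defined with no value as well as one defined to a nonzero value. Below is a minimal standalone sketch of that behavior, using the macro definition shown in the diff. The `DEMO_*` macros and `kDemoEnabled` are hypothetical names introduced only for illustration; they are not part of optree or pybind11.

```cpp
// Same definition as shown in include/optree/pymacros.h.
#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0))

// Hypothetical feature macros, for illustration only.
#define DEMO_EMPTY    // defined with no value
#define DEMO_ONE 1    // defined to a nonzero value
#define DEMO_ZERO 0   // defined to zero

// Empty:   (+ 0 != 0) is false, but (0 - - 1 >= 0) is (1 >= 0), which is true.
static_assert(NONZERO_OR_EMPTY(DEMO_EMPTY), "an empty definition counts as enabled");
// Nonzero: (1 + 0 != 0) is true.
static_assert(NONZERO_OR_EMPTY(DEMO_ONE), "a nonzero definition counts as enabled");
// Zero:    (0 + 0 != 0) is false and (0 - 0 - 1 >= 0) is false.
static_assert(!NONZERO_OR_EMPTY(DEMO_ZERO), "a zero definition counts as disabled");

// The same expression also works in a preprocessor conditional,
// which is how the guard in the diff uses it.
#if NONZERO_OR_EMPTY(DEMO_EMPTY)
constexpr bool kDemoEnabled = true;
#else
constexpr bool kDemoEnabled = false;
#endif
```

Note that the macro must still be combined with `defined(...)`, as the diff does, because an undefined identifier evaluates to `0` inside `#if`.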
37 changes: 23 additions & 14 deletions include/optree/pytypes.h
@@ -218,8 +218,7 @@ inline bool IsNamedTupleClassImpl(const py::handle &type) {
// We can only identify namedtuples heuristically, here by the presence of a _fields attribute.
if (PyType_FastSubclass(reinterpret_cast<PyTypeObject *>(type.ptr()),
Py_TPFLAGS_TUPLE_SUBCLASS)) [[unlikely]] {
if (PyObject * const _fields = PyObject_GetAttr(type.ptr(), Py_Get_ID(_fields)))
[[unlikely]] {
if (PyObject * const _fields = PyObject_GetAttrString(type.ptr(), "_fields")) [[unlikely]] {
bool fields_ok = static_cast<bool>(PyTuple_CheckExact(_fields));
if (fields_ok) [[likely]] {
for (const auto &field : py::reinterpret_borrow<py::tuple>(_fields)) {
@@ -232,8 +231,9 @@ inline bool IsNamedTupleClassImpl(const py::handle &type) {
Py_DECREF(_fields);
if (fields_ok) [[likely]] {
// NOLINTNEXTLINE[readability-use-anyofallof]
for (PyObject * const name : {Py_Get_ID(_make), Py_Get_ID(_asdict)}) {
if (PyObject * const attr = PyObject_GetAttr(type.ptr(), name)) [[likely]] {
for (const char * const name : {"_make", "_asdict"}) {
if (PyObject * const attr = PyObject_GetAttrString(type.ptr(), name))
[[likely]] {
const bool result = static_cast<bool>(PyCallable_Check(attr));
Py_DECREF(attr);
if (!result) [[unlikely]] {
@@ -261,6 +261,7 @@ inline bool IsNamedTupleClass(const py::handle &type) {
static read_write_mutex mutex{};

{
const py::gil_scoped_release_simple gil_release{};
const scoped_read_lock lock{mutex};
const auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
@@ -270,8 +271,10 @@ inline bool IsNamedTupleClass(const py::handle &type) {

const bool result = EVALUATE_WITH_LOCK_HELD(IsNamedTupleClassImpl(type), type);
{
const py::gil_scoped_release_simple gil_release{};
const scoped_write_lock lock{mutex};
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
const py::gil_scoped_acquire_simple gil_acquire{};
cache.emplace(type, result);
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
const scoped_write_lock lock{mutex};
@@ -311,7 +314,7 @@ inline py::tuple NamedTupleGetFields(const py::handle &object) {
PyRepr(object) + ".");
}
}
return EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(_fields)), type);
return EVALUATE_WITH_LOCK_HELD(py::getattr(type, "_fields"), type);
}
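The type caches touched in `IsNamedTupleClass` above share one shape: look up under a read lock, compute the answer outside the lock on a miss, then insert under a write lock only while the cache is below a size cap. The following is a rough standalone sketch of that shape alone, under several assumptions: it uses `std::shared_mutex` in place of optree's `read_write_mutex`, keys the cache by `std::string` instead of `py::handle`, and does not model the GIL release/acquire or the weakref-based eviction seen in the diff. `BoundedPredicateCache`, `kMaxCacheSize`, and `DemoBoundedCache` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Hypothetical cap; the diff uses MAX_TYPE_CACHE_SIZE for the same purpose.
constexpr std::size_t kMaxCacheSize = 4;

class BoundedPredicateCache {
public:
    template <typename Fn>
    bool GetOrCompute(const std::string &key, Fn &&compute) {
        {
            // Fast path: many readers may hold the shared lock at once.
            const std::shared_lock<std::shared_mutex> lock{m_mutex};
            const auto it = m_cache.find(key);
            if (it != m_cache.end()) {
                return it->second;
            }
        }
        // Miss: compute without holding any lock.
        const bool result = compute(key);
        {
            // Slow path: exclusive lock, and only grow the cache below the cap.
            const std::unique_lock<std::shared_mutex> lock{m_mutex};
            if (m_cache.size() < kMaxCacheSize) {
                m_cache.emplace(key, result);  // no-op if another thread inserted first
            }
        }
        return result;
    }

    std::size_t Size() const {
        const std::shared_lock<std::shared_mutex> lock{m_mutex};
        return m_cache.size();
    }

private:
    mutable std::shared_mutex m_mutex{};
    std::unordered_map<std::string, bool> m_cache{};
};

// Usage: the second lookup is served from the cache, so the predicate runs once.
inline void DemoBoundedCache() {
    BoundedPredicateCache cache;
    int calls = 0;
    const auto is_long = [&calls](const std::string &name) {
        ++calls;
        return name.size() > 3;
    };
    assert(cache.GetOrCompute("tuple", is_long));
    assert(cache.GetOrCompute("tuple", is_long));
    assert(calls == 1);
}
```

Two threads that miss at the same time may both compute the result; that is harmless here because the predicate is deterministic and `emplace` keeps the first inserted value.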

inline bool IsStructSequenceClassImpl(const py::handle &type) {
@@ -325,9 +328,8 @@ inline bool IsStructSequenceClassImpl(const py::handle &type) {
PyTuple_GET_ITEM(type_object->tp_bases, 0) == reinterpret_cast<PyObject *>(&PyTuple_Type))
[[unlikely]] {
// NOLINTNEXTLINE[readability-use-anyofallof]
for (PyObject * const name :
{Py_Get_ID(n_fields), Py_Get_ID(n_sequence_fields), Py_Get_ID(n_unnamed_fields)}) {
if (PyObject * const attr = PyObject_GetAttr(type.ptr(), name)) [[unlikely]] {
for (const char * const name : {"n_fields", "n_sequence_fields", "n_unnamed_fields"}) {
if (PyObject * const attr = PyObject_GetAttrString(type.ptr(), name)) [[unlikely]] {
const bool result = static_cast<bool>(PyLong_CheckExact(attr));
Py_DECREF(attr);
if (!result) [[unlikely]] {
@@ -364,6 +366,7 @@ inline bool IsStructSequenceClass(const py::handle &type) {
static read_write_mutex mutex{};

{
const py::gil_scoped_release_simple gil_release{};
const scoped_read_lock lock{mutex};
const auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
@@ -373,8 +376,10 @@ inline bool IsStructSequenceClass(const py::handle &type) {

const bool result = EVALUATE_WITH_LOCK_HELD(IsStructSequenceClassImpl(type), type);
{
const py::gil_scoped_release_simple gil_release{};
const scoped_write_lock lock{mutex};
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
const py::gil_scoped_acquire_simple gil_acquire{};
cache.emplace(type, result);
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
const scoped_write_lock lock{mutex};
@@ -418,7 +423,7 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle &type) {
return py::tuple{fields};
#else
const auto n_sequence_fields = thread_safe_cast<py::ssize_t>(
EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(n_sequence_fields)), type));
EVALUATE_WITH_LOCK_HELD(py::getattr(type, "n_sequence_fields"), type));
const auto * const members = reinterpret_cast<PyTypeObject *>(type.ptr())->tp_members;
py::tuple fields{n_sequence_fields};
for (py::ssize_t i = 0; i < n_sequence_fields; ++i) {
@@ -447,17 +452,21 @@ inline py::tuple StructSequenceGetFields(const py::handle &object) {
static read_write_mutex mutex{};

{
const py::gil_scoped_release_simple gil_release{};
const scoped_read_lock lock{mutex};
const auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
const py::gil_scoped_acquire_simple gil_acquire{};
return py::reinterpret_borrow<py::tuple>(it->second);
}
}

const py::tuple fields = EVALUATE_WITH_LOCK_HELD(StructSequenceGetFieldsImpl(type), type);
{
const py::gil_scoped_release_simple gil_release{};
const scoped_write_lock lock{mutex};
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
const py::gil_scoped_acquire_simple gil_acquire{};
cache.emplace(type, fields);
fields.inc_ref();
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
@@ -489,15 +498,15 @@ inline void TotalOrderSort(py::list &list) { // NOLINT[runtime/references]
// Sort with `(f'{obj.__class__.__module__}.{obj.__class__.__qualname__}', obj)`
const auto sort_key_fn = py::cpp_function([](const py::object &obj) -> py::tuple {
const py::handle cls = py::type::handle_of(obj);
const py::str qualname{EVALUATE_WITH_LOCK_HELD(
PyStr(py::getattr(cls, Py_Get_ID(__module__))) + "." +
PyStr(py::getattr(cls, Py_Get_ID(__qualname__))),
cls)};
const py::str qualname{
EVALUATE_WITH_LOCK_HELD(PyStr(py::getattr(cls, "__module__")) + "." +
PyStr(py::getattr(cls, "__qualname__")),
cls)};
return py::make_tuple(qualname, obj);
});
{
const scoped_critical_section cs{list};
py::getattr(list, Py_Get_ID(sort))(py::arg("key") = sort_key_fn);
py::getattr(list, "sort")(py::arg("key") = sort_key_fn);
}
} catch (py::error_already_set &ex2) {
if (ex2.matches(PyExc_TypeError)) [[likely]] {
22 changes: 22 additions & 0 deletions include/optree/registry.h
@@ -29,6 +29,7 @@ limitations under the License.

#include "optree/exceptions.h"
#include "optree/hashing.h"
#include "optree/pymacros.h"
#include "optree/synchronization.h"

namespace optree {
@@ -141,6 +142,25 @@ class PyTreeTypeRegistry {
return count1;
}

// Get the number of alive interpreters that have seen the registry.
[[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() {
const scoped_read_lock lock{sm_mutex};
return py::ssize_t_cast(sm_alive_interpids.size());
}

// Get the number of interpreters that have seen the registry.
[[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() {
const scoped_read_lock lock{sm_mutex};
return sm_num_interpreters_seen;
}

// Get the IDs of alive interpreters that have seen the registry.
[[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set<interpid_t>
GetAliveInterpreterIDs() {
const scoped_read_lock lock{sm_mutex};
return sm_alive_interpids;
}

friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references]

private:
@@ -173,7 +193,9 @@ class PyTreeTypeRegistry {
NamedRegistrationsMap m_named_registrations{};
BuiltinsTypesSet m_builtins_types{};

static inline std::unordered_set<interpid_t> sm_alive_interpids{};
static inline read_write_mutex sm_mutex{};
static inline ssize_t sm_num_interpreters_seen = 0;
};

} // namespace optree
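The new registry members separate two quantities: the set of interpreter IDs that are currently alive, and a running count of interpreters ever seen. This diff does not show where those members are updated, so the sketch below is only one plausible arrangement, not optree's implementation. `InterpreterTracker`, `MarkSeen`, `MarkGone`, and `DemoInterpreterTracker` are hypothetical names, `std::shared_mutex` stands in for `read_write_mutex`, and `interpid_t` is fixed to `std::int64_t` rather than deduced via `decltype` as in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <unordered_set>

using interpid_t = std::int64_t;  // stand-in for the decltype-deduced type in the diff

class InterpreterTracker {
public:
    // Hypothetical hook: an interpreter has initialized the module.
    static void MarkSeen(interpid_t id) {
        const std::unique_lock<std::shared_mutex> lock{sm_mutex};
        if (sm_alive_interpids.insert(id).second) {
            ++sm_num_interpreters_seen;
        }
    }

    // Hypothetical hook: an interpreter has been finalized.
    static void MarkGone(interpid_t id) {
        const std::unique_lock<std::shared_mutex> lock{sm_mutex};
        sm_alive_interpids.erase(id);
    }

    [[nodiscard]] static std::size_t NumAlive() {
        const std::shared_lock<std::shared_mutex> lock{sm_mutex};
        return sm_alive_interpids.size();
    }

    [[nodiscard]] static std::size_t NumSeen() {
        const std::shared_lock<std::shared_mutex> lock{sm_mutex};
        return sm_num_interpreters_seen;
    }

    // Returns a copy so callers never touch the set without the lock.
    [[nodiscard]] static std::unordered_set<interpid_t> AliveIDs() {
        const std::shared_lock<std::shared_mutex> lock{sm_mutex};
        return sm_alive_interpids;
    }

private:
    static inline std::unordered_set<interpid_t> sm_alive_interpids{};
    static inline std::shared_mutex sm_mutex{};
    static inline std::size_t sm_num_interpreters_seen = 0;
};

// Usage: "seen" only grows, while "alive" shrinks when an interpreter goes away.
inline void DemoInterpreterTracker() {
    InterpreterTracker::MarkSeen(0);  // e.g. the main interpreter
    InterpreterTracker::MarkSeen(7);  // e.g. a subinterpreter
    InterpreterTracker::MarkGone(7);  // the subinterpreter is finalized
    assert(InterpreterTracker::NumSeen() == 2);
    assert(InterpreterTracker::NumAlive() == 1);
}
```

Returning the ID set by value, as `GetAliveInterpreterIDs()` does in the diff, trades a copy for not exposing a reference that could be read after the lock is released.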
28 changes: 15 additions & 13 deletions include/optree/treespec.h
@@ -23,7 +23,7 @@ limitations under the License.
#include <thread> // std::thread::id
#include <tuple> // std::tuple
#include <unordered_set> // std::unordered_set
#include <utility> // std::pair
#include <utility> // std::pair, std::make_pair
#include <vector> // std::vector

#include <pybind11/pybind11.h>
@@ -263,28 +263,31 @@ class PyTreeSpec {
[[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered(
const std::string &registry_namespace,
const bool &inherit_global_namespace = true) {
const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex};
const scoped_read_lock lock{sm_dict_order_mutex};

return (sm_is_dict_insertion_ordered.find(registry_namespace) !=
sm_is_dict_insertion_ordered.end()) ||
(inherit_global_namespace &&
sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end());
const auto interpid = GetCurrentPyInterpreterID();
const auto &namespaces = sm_dict_insertion_ordered_namespaces;
return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) ||
(inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end());
}

// Set the namespace to preserve the insertion order of the dictionary keys during flattening.
static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered(
const bool &mode,
const std::string &registry_namespace) {
const scoped_write_lock lock{sm_is_dict_insertion_ordered_mutex};
const scoped_write_lock lock{sm_dict_order_mutex};

const auto interpid = GetCurrentPyInterpreterID();
const auto key = std::make_pair(interpid, registry_namespace);
if (mode) [[likely]] {
sm_is_dict_insertion_ordered.insert(registry_namespace);
sm_dict_insertion_ordered_namespaces.insert(key);
} else [[unlikely]] {
sm_is_dict_insertion_ordered.erase(registry_namespace);
sm_dict_insertion_ordered_namespaces.erase(key);
}
}

friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references]
friend class PyTreeTypeRegistry;

private:
using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr;
@@ -426,8 +429,9 @@ class PyTreeSpec {

// A set of namespaces that preserve the insertion order of the dictionary keys during
// flattening.
static inline std::unordered_set<std::string> sm_is_dict_insertion_ordered{};
static inline read_write_mutex sm_is_dict_insertion_ordered_mutex{};
static inline std::unordered_set<std::pair<interpid_t, std::string>>
sm_dict_insertion_ordered_namespaces{};
static inline read_write_mutex sm_dict_order_mutex{};
};
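The change above keys the dict-order set by `(interpreter ID, namespace)` instead of by namespace alone, so a setting made in one interpreter is invisible in another. The following standalone sketch shows that lookup, including the fallback to the global namespace `""`. The standard library has no `std::hash` for `std::pair`, so the sketch supplies a hypothetical `PairHash`; optree presumably gets its hasher from `optree/hashing.h`, which is not part of this diff. `DictOrderFlags` and `DemoDictOrderFlags` are also hypothetical, `interpid_t` is fixed to `std::int64_t`, and locking is omitted for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_set>
#include <utility>

using interpid_t = std::int64_t;  // stand-in for the decltype-deduced type in the diff
using NamespaceKey = std::pair<interpid_t, std::string>;

// Hypothetical hasher that combines the hashes of both halves of the key.
struct PairHash {
    std::size_t operator()(const NamespaceKey &key) const noexcept {
        const std::size_t h1 = std::hash<interpid_t>{}(key.first);
        const std::size_t h2 = std::hash<std::string>{}(key.second);
        return h1 ^ (h2 + 0x9e3779b9U + (h1 << 6) + (h1 >> 2));
    }
};

class DictOrderFlags {
public:
    void Set(interpid_t interpid, const std::string &registry_namespace, bool mode) {
        const auto key = std::make_pair(interpid, registry_namespace);
        if (mode) {
            m_namespaces.insert(key);
        } else {
            m_namespaces.erase(key);
        }
    }

    [[nodiscard]] bool IsOrdered(interpid_t interpid,
                                 const std::string &registry_namespace,
                                 bool inherit_global_namespace = true) const {
        return (m_namespaces.count({interpid, registry_namespace}) != 0) ||
               (inherit_global_namespace && m_namespaces.count({interpid, ""}) != 0);
    }

private:
    std::unordered_set<NamespaceKey, PairHash> m_namespaces{};
};

// Usage: the same namespace name is tracked independently per interpreter.
inline void DemoDictOrderFlags() {
    DictOrderFlags flags;
    flags.Set(1, "demo", true);
    assert(flags.IsOrdered(1, "demo"));
    assert(!flags.IsOrdered(2, "demo"));
}
```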

class PyTreeIter {
@@ -464,9 +468,7 @@ class PyTreeIter {
const bool m_none_is_leaf;
const std::string m_namespace;
const bool m_is_dict_insertion_ordered;
#if defined(Py_GIL_DISABLED)
mutable mutex m_mutex{};
#endif

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
9 changes: 9 additions & 0 deletions optree/_C.pyi
@@ -49,10 +49,13 @@ Py_DEBUG: Final[bool]
Py_GIL_DISABLED: Final[bool]
PYBIND11_VERSION_HEX: Final[int]
PYBIND11_INTERNALS_VERSION: Final[int]
PYBIND11_INTERNALS_ID: Final[str]
PYBIND11_MODULE_LOCAL_ID: Final[str]
PYBIND11_HAS_NATIVE_ENUM: Final[bool]
PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool]
PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool]
GLIBCXX_USE_CXX11_ABI: Final[bool]
OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool]

@final
class InternalError(SystemError): ...
@@ -214,3 +217,9 @@ def set_dict_insertion_ordered(
namespace: str = '',
) -> None: ...
def get_registry_size(namespace: str | None = None) -> int: ...
def get_num_interpreters_seen() -> int: ...
def get_num_interpreters_alive() -> int: ...
def get_alive_interpreter_ids() -> set[int]: ...
def is_current_interpreter_main() -> bool: ...
def get_current_interpreter_id() -> int: ...
def get_main_interpreter_id() -> int: ...