Skip to content

Commit

Permalink
Add pybind11 type caster for JSONValues (#458)
Browse files Browse the repository at this point in the history
* Add pybind11 type-caster for `JSONValues`

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #458
  • Loading branch information
dagardner-nv authored Apr 5, 2024
1 parent f4e6266 commit 5242760
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 2 deletions.
57 changes: 56 additions & 1 deletion python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "pymrc/types.hpp" // for python_map_t & unserializable_handler_fn_t

#include <nlohmann/json.hpp>
#include <pybind11/pytypes.h> // for PYBIND11_EXPORT & pybind11::object
#include <pybind11/pybind11.h> // for PYBIND11_EXPORT, pybind11::object, type_caster

#include <cstddef> // for size_t
#include <string>
Expand Down Expand Up @@ -155,3 +155,58 @@ class PYBIND11_EXPORT JSONValues

#pragma GCC visibility pop
} // namespace mrc::pymrc

/****** Pybind11 caster ******************/

// NOLINTNEXTLINE(modernize-concat-nested-namespaces)
namespace PYBIND11_NAMESPACE {
namespace detail {

template <>
struct type_caster<mrc::pymrc::JSONValues>
{
public:
/**
* This macro establishes a local variable 'value' of type JSONValues
*/
PYBIND11_TYPE_CASTER(mrc::pymrc::JSONValues, _("object"));

/**
* Conversion part 1 (Python->C++): convert a PyObject into JSONValues
* instance or return false upon failure. The second argument
* indicates whether implicit conversions should be applied.
*/
bool load(handle src, bool convert)
{
if (!src)
{
return false;
}

if (src.is_none())
{
value = mrc::pymrc::JSONValues();
}
else
{
value = std::move(mrc::pymrc::JSONValues(pybind11::reinterpret_borrow<pybind11::object>(src)));
}

return true;
}

/**
* Conversion part 2 (C++ -> Python): convert a JSONValues instance into
* a Python object. The second and third arguments are used to
* indicate the return value policy and parent object (for
* ``return_value_policy::reference_internal``) and are generally
* ignored by implicit casters.
*/
static handle cast(mrc::pymrc::JSONValues src, return_value_policy policy, handle parent)
{
return src.to_python().release();
}
};

} // namespace detail
} // namespace PYBIND11_NAMESPACE
17 changes: 17 additions & 0 deletions python/mrc/_pymrc/tests/test_json_values.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ TEST_F(TestJSONValues, ToPythonRootUnserializable)

EXPECT_TRUE(result.equal(py_dec));
EXPECT_TRUE(result.is(py_dec)); // Ensure we stored the object

nlohmann::json expexted_json("**pymrc_placeholder"s);
EXPECT_EQ(j.view_json(), expexted_json);
}

TEST_F(TestJSONValues, ToPythonSimpleDict)
Expand Down Expand Up @@ -542,3 +545,17 @@ TEST_F(TestJSONValues, Stringify)
auto dec_val = mk_decimal("2.2"s);
EXPECT_EQ(JSONValues::stringify(dec_val, "/"s), nlohmann::json("2.2"s));
}

TEST_F(TestJSONValues, CastPyToJSONValues)
{
auto py_dict = mk_py_dict();
auto j = py_dict.cast<JSONValues>();
EXPECT_TRUE(j.to_python().equal(py_dict));
}

TEST_F(TestJSONValues, CastJSONValuesToPy)
{
auto j = JSONValues{mk_json()};
auto py_dict = py::cast(j);
EXPECT_TRUE(py_dict.equal(j.to_python()));
}
11 changes: 10 additions & 1 deletion python/mrc/tests/utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -17,6 +17,8 @@

#include "pymrc/utils.hpp"

#include "pymrc/utilities/json_values.hpp" // for JSONValues

#include "mrc/utils/string_utils.hpp"
#include "mrc/version.hpp"

Expand All @@ -41,6 +43,11 @@ struct RequireGilInDestructor
}
};

pymrc::JSONValues roundtrip_cast(pymrc::JSONValues v)
{
return v;
}

PYBIND11_MODULE(utils, py_mod)
{
py_mod.doc() = R"pbdoc()pbdoc";
Expand All @@ -61,6 +68,8 @@ PYBIND11_MODULE(utils, py_mod)

py::class_<RequireGilInDestructor>(py_mod, "RequireGilInDestructor").def(py::init<>());

py_mod.def("roundtrip_cast", &roundtrip_cast, py::arg("v"));

py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "."
<< mrc_VERSION_PATCH);
}
Expand Down
48 changes: 48 additions & 0 deletions python/tests/test_json_values_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from decimal import Decimal

import pytest

from mrc.tests.utils import roundtrip_cast


def test_docstrings():
expected_docstring = "roundtrip_cast(v: object) -> object"
docstring = inspect.getdoc(roundtrip_cast)
assert docstring == expected_docstring


@pytest.mark.parametrize(
"value",
[
12,
2.4,
RuntimeError("test"),
Decimal("1.2"),
"test", [1, 2, 3], {
"a": 1, "b": 2
}, {
"a": 1, "b": RuntimeError("not serializable")
}, {
"a": 1, "b": Decimal("1.3")
}
],
ids=["int", "float", "exception", "decimal", "str", "list", "dict", "dict_w_exception", "dict_w_decimal"])
def test_cast_roundtrip(value: object):
result = roundtrip_cast(value)
assert result == value

0 comments on commit 5242760

Please sign in to comment.