Skip to content

Commit

Permalink
Add tests for JSONValues python cast
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Apr 4, 2024
1 parent a8052c7 commit e439d5b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/mrc/tests/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include "pymrc/utils.hpp"

#include "pymrc/utilities/json_values.hpp" // for JSONValues

#include "mrc/utils/string_utils.hpp"
#include "mrc/version.hpp"

Expand All @@ -41,6 +43,11 @@ struct RequireGilInDestructor
}
};

pymrc::JSONValues roundtrip_cast(pymrc::JSONValues v)
{
return v;
}

PYBIND11_MODULE(utils, py_mod)
{
py_mod.doc() = R"pbdoc()pbdoc";
Expand All @@ -61,6 +68,8 @@ PYBIND11_MODULE(utils, py_mod)

py::class_<RequireGilInDestructor>(py_mod, "RequireGilInDestructor").def(py::init<>());

py_mod.def("roundtrip_cast", &roundtrip_cast, py::arg("v"));

py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "."
<< mrc_VERSION_PATCH);
}
Expand Down
48 changes: 48 additions & 0 deletions python/tests/test_json_values_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from decimal import Decimal

import pytest

from mrc.tests.utils import roundtrip_cast


def test_docstrings():
expected_docstring = "roundtrip_cast(v: object) -> object"
docstring = inspect.getdoc(roundtrip_cast)
assert docstring == expected_docstring


@pytest.mark.parametrize(
"value",
[
12,
2.4,
RuntimeError("test"),
Decimal("1.2"),
"test", [1, 2, 3], {
"a": 1, "b": 2
}, {
"a": 1, "b": RuntimeError("not serializable")
}, {
"a": 1, "b": Decimal("1.3")
}
],
ids=["int", "float", "exception", "decimal", "str", "list", "dict", "dict_w_exception", "dict_w_decimal"])
def test_cast_roundtrip(value: object):
result = roundtrip_cast(value)
assert result == value

0 comments on commit e439d5b

Please sign in to comment.