Skip to content

Commit

Permalink
For module configs add an unserializable handler to cast to null, res…
Browse files Browse the repository at this point in the history
…tores previous behavior prior to PR #451
  • Loading branch information
dagardner-nv committed Apr 4, 2024
1 parent 86875b9 commit 7c500a0
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 12 deletions.
19 changes: 17 additions & 2 deletions python/mrc/_pymrc/include/pymrc/types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,9 +21,12 @@

#include "mrc/segment/object.hpp"

#include <nlohmann/json_fwd.hpp>
#include <rxcpp/rx.hpp>

#include <functional>
#include <functional> // for function
#include <map>
#include <string>

namespace mrc::pymrc {

Expand All @@ -37,4 +40,16 @@ using PyNode = mrc::segment::ObjectProperties;
using PyObjectOperateFn = std::function<PyObjectObservable(PyObjectObservable source)>;
// NOLINTEND(readability-identifier-naming)

using python_map_t = std::map<std::string, pybind11::object>;

/**
* @brief Unserializable handler function type, invoked by `cast_from_pyobject` when an object cannot be serialized to
* JSON. Implementations should return a valid json object, or throw an exception if the object cannot be serialized.
* @param source : pybind11 object
* @param path : string json path to object
* @return nlohmann::json.
*/
using unserializable_handler_fn_t =
std::function<nlohmann::json(const pybind11::object& /* source*/, const std::string& /* path */)>;

} // namespace mrc::pymrc
19 changes: 19 additions & 0 deletions python/mrc/_pymrc/include/pymrc/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#pragma once

#include "pymrc/types.hpp"

#include <nlohmann/json_fwd.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
Expand All @@ -31,8 +33,25 @@ namespace mrc::pymrc {
#pragma GCC visibility push(default)

pybind11::object cast_from_json(const nlohmann::json& source);

/**
* @brief Convert a pybind11 object to a JSON object. If the object cannot be serialized, a pybind11::type_error
* exception be thrown.
* @param source : pybind11 object
* @return nlohmann::json.
*/
nlohmann::json cast_from_pyobject(const pybind11::object& source);

/**
* @brief Convert a pybind11 object to a JSON object. If the object cannot be serialized, the unserializable_handler_fn
* will be invoked to handle the object.
* @param source : pybind11 object
* @param unserializable_handler_fn : unserializable_handler_fn_t
* @return nlohmann::json.
*/
nlohmann::json cast_from_pyobject(const pybind11::object& source,
unserializable_handler_fn_t unserializable_handler_fn);

void import_module_object(pybind11::module_&, const std::string&, const std::string&);
void import_module_object(pybind11::module_& dest, const pybind11::module_& mod);

Expand Down
5 changes: 4 additions & 1 deletion python/mrc/_pymrc/src/module_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ pybind11::cpp_function ModuleRegistryProxy::get_module_constructor(const std::st
{
auto fn_constructor = modules::ModuleRegistry::get_module_constructor(name, registry_namespace);
auto py_module_wrapper = [fn_constructor](std::string module_name, pybind11::dict config) {
auto json_config = cast_from_pyobject(config);
auto json_config = cast_from_pyobject(config, [](const pybind11::object&, const std::string& path) {
DVLOG(10) << "Could not serialize object at path: " << path;
return nlohmann::json(); // Return a null json object if we can't convert
});
return fn_constructor(std::move(module_name), std::move(json_config));
};

Expand Down
5 changes: 4 additions & 1 deletion python/mrc/_pymrc/src/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,10 @@ std::shared_ptr<mrc::modules::SegmentModule> BuilderProxy::load_module_from_regi
std::string module_name,
py::dict config)
{
auto json_config = cast_from_pyobject(config);
auto json_config = cast_from_pyobject(config, [](const py::object&, const std::string& path) {
DVLOG(10) << "Could not serialize object at path: " << path;
return nlohmann::json(); // Return a null json object if we can't convert
});

return self.load_module_from_registry(module_id, registry_namespace, std::move(module_name), std::move(json_config));
}
Expand Down
29 changes: 24 additions & 5 deletions python/mrc/_pymrc/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
#include <pyerrors.h>
#include <warnings.h>

#include <functional> // for function
#include <sstream>
#include <string>
#include <utility>

namespace mrc::pymrc {

namespace py = pybind11;

using nlohmann::json;
Expand Down Expand Up @@ -139,14 +139,17 @@ py::object cast_from_json(const json& source)
// throw std::runtime_error("Unsupported conversion type.");
}

json cast_from_pyobject_impl(const py::object& source, const std::string& parent_path = "")
json cast_from_pyobject_impl(const py::object& source,
unserializable_handler_fn_t unserializable_handler_fn,
const std::string& parent_path = "")
{
// Dont return via initializer list with JSON. It performs type deduction and gives different results
// NOLINTBEGIN(modernize-return-braced-init-list)
if (source.is_none())
{
return json();
}

if (py::isinstance<py::dict>(source))
{
const auto py_dict = source.cast<py::dict>();
Expand All @@ -155,34 +158,40 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent
{
std::string key{p.first.cast<std::string>()};
std::string path{parent_path + "/" + key};
json_obj[key] = cast_from_pyobject_impl(p.second.cast<py::object>(), path);
json_obj[key] = cast_from_pyobject_impl(p.second.cast<py::object>(), unserializable_handler_fn, path);
}

return json_obj;
}

if (py::isinstance<py::list>(source) || py::isinstance<py::tuple>(source))
{
const auto py_list = source.cast<py::list>();
auto json_arr = json::array();
for (const auto& p : py_list)
{
json_arr.push_back(cast_from_pyobject_impl(p.cast<py::object>(), parent_path));
std::string path{parent_path + "/" + std::to_string(json_arr.size())};
json_arr.push_back(cast_from_pyobject_impl(p.cast<py::object>(), unserializable_handler_fn, path));
}

return json_arr;
}

if (py::isinstance<py::bool_>(source))
{
return json(py::cast<bool>(source));
}

if (py::isinstance<py::int_>(source))
{
return json(py::cast<long>(source));
}

if (py::isinstance<py::float_>(source))
{
return json(py::cast<double>(source));
}

if (py::isinstance<py::str>(source))
{
return json(py::cast<std::string>(source));
Expand All @@ -198,6 +207,11 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent
path = "/";
}

if (unserializable_handler_fn != nullptr)
{
return unserializable_handler_fn(source, path);
}

error_message << "Object (" << py::str(source).cast<std::string>() << ") of type: " << get_py_type_name(source)
<< " at path: " << path << " is not JSON serializable";

Expand All @@ -208,9 +222,14 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent
// NOLINTEND(modernize-return-braced-init-list)
}

json cast_from_pyobject(const py::object& source, unserializable_handler_fn_t unserializable_handler_fn)
{
return cast_from_pyobject_impl(source, unserializable_handler_fn);
}

json cast_from_pyobject(const py::object& source)
{
return cast_from_pyobject_impl(source);
return cast_from_pyobject_impl(source, nullptr);
}

void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level)
Expand Down
6 changes: 3 additions & 3 deletions python/tests/test_module_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def test_get_module_constructor():
registry.get_module_constructor("SimpleModule", "default")


@pytest.mark.parametrize("config,name", [({"a": "b"}, "serializable"), ({"now": datetime.now()}, "unserializable")])
def test_module_config(config: dict, name: str):
def test_module_config():
"""
Repro test for #461
"""
module_name = f"test_py_mod_config_{name}"
config = {"now": datetime.now()}
module_name = f"test_py_mod_config_unserializable"
registry = mrc.ModuleRegistry

def module_initializer(builder: mrc.Builder):
Expand Down

0 comments on commit 7c500a0

Please sign in to comment.