Guard against unused arguments in script interface #4941

Open
RudolfWeeber opened this issue Jun 19, 2024 · 2 comments

Comments

@RudolfWeeber
Contributor

Following up on the removal of act_on_virtual from the LB thermostat.

One solution we could investigate is a wrapper around the VariantMap holding the parameters, which keeps track of consumed parameters and throws in the destructor if any parameter wasn't consumed.

VariantMap is a std::unordered_map<std::string, Variant>.

ChatGPT came up with the following, which at least conveys the idea.

Prompt: Write a wrapper around std::unordered_map, which keeps track of which elements have been accessed and throws in the destructor if not all elements were used
Here is a wrapper around std::unordered_map that keeps track of which elements have been accessed and throws an exception in the destructor if not all elements were used.

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

template <typename Key, typename T>
class AccessTrackingUnorderedMap {
public:
    using MapType = std::unordered_map<Key, T>;

    AccessTrackingUnorderedMap() = default;

    T& operator[](const Key& key) {
        accessed_keys.insert(key);
        return map[key];
    }

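    const T& operator[](const Key& key) const {
        const T& value = map.at(key);  // throws std::out_of_range for a missing key
        accessed_keys.insert(key);     // only record keys that exist in the map
        return value;
    }

    typename MapType::iterator find(const Key& key) {
        auto it = map.find(key);
        if (it != map.end()) {
            accessed_keys.insert(key);  // don't record missing keys, or the size check breaks
        }
        return it;
    }

    typename MapType::const_iterator find(const Key& key) const {
        auto it = map.find(key);
        if (it != map.end()) {
            accessed_keys.insert(key);
        }
        return it;
    }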

    // Other necessary map functions can be added similarly...

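    // NOTE: since C++11, destructors are implicitly noexcept, so this throw
    // calls std::terminate() instead of propagating (GCC warns via -Wterminate).
    ~AccessTrackingUnorderedMap() {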
        if (accessed_keys.size() != map.size()) {
            throw std::runtime_error("Not all elements were accessed!");
        }
    }

    // Functions to insert and manipulate the map
    std::pair<typename MapType::iterator, bool> insert(const typename MapType::value_type& value) {
        return map.insert(value);
    }

    std::pair<typename MapType::iterator, bool> insert(typename MapType::value_type&& value) {
        return map.insert(std::move(value));
    }

    template <typename... Args>
    std::pair<typename MapType::iterator, bool> emplace(Args&&... args) {
        return map.emplace(std::forward<Args>(args)...);
    }

    size_t erase(const Key& key) {
        accessed_keys.erase(key);
        return map.erase(key);
    }

    void clear() {
        accessed_keys.clear();
        map.clear();
    }

    size_t size() const {
        return map.size();
    }

    bool empty() const {
        return map.empty();
    }

    typename MapType::iterator begin() {
        return map.begin();
    }

    typename MapType::const_iterator begin() const {
        return map.begin();
    }

    typename MapType::iterator end() {
        return map.end();
    }

    typename MapType::const_iterator end() const {
        return map.end();
    }

private:
    MapType map;
    mutable std::unordered_set<Key> accessed_keys;
};

// Example usage
int main() {
    try {
        AccessTrackingUnorderedMap<int, std::string> myMap;
        myMap.insert({1, "one"});
        myMap.insert({2, "two"});
        myMap.insert({3, "three"});

        std::cout << myMap[1] << std::endl; // Accessing element with key 1
        std::cout << myMap[2] << std::endl; // Accessing element with key 2
        // Not accessing element with key 3

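        // Intended: the destructor throws because key 3 was never accessed.
        // In practice the destructor is noexcept, so std::terminate() is called
        // and the catch block below is never reached.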
    } catch (const std::exception& e) {
        std::cerr << e.what() << std::endl;
    }
    return 0;
}
@jngrad
Member

jngrad commented Jun 19, 2024

I'm not totally confident we can implement such behavior in the script interface. It was originally designed to emulate the behavior of a Python def call_method(self, **kwargs) function. Users are free to pass extra unused arguments, just like they are free to assign extra unused class members (e.g. the np.random.seed = 42 fiasco).
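To make this concrete, here is a minimal Python sketch of a **kwargs-style method that only reads the keys it knows about. The Thermostat class and its set_params() method are hypothetical stand-ins, not ESPResSo API. An obsolete keyword such as act_on_virtual is accepted without error, and a misspelled attribute assignment silently creates a new attribute:

```python
class Thermostat:
    """Hypothetical stand-in for a script interface object (not ESPResSo API)."""

    def __init__(self):
        self.params = {"kT": 0.0, "gamma": 0.0}

    def set_params(self, **kwargs):
        # Only the known keys are read; anything else in kwargs is silently ignored.
        for key in ("kT", "gamma"):
            if key in kwargs:
                self.params[key] = kwargs[key]


thermostat = Thermostat()
# "act_on_virtual" is no longer a parameter, yet no exception is raised.
thermostat.set_params(kT=1.0, gamma=2.0, act_on_virtual=True)
print(thermostat.params)  # {'kT': 1.0, 'gamma': 2.0}

# Likewise, assigning a misspelled attribute silently creates a new one.
thermostat.kt = 5.0
print(thermostat.params["kT"])  # 1.0
```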

I can think of a few challenges to the deployment of the proposed solution:

  1. Since C++11, destructors are implicitly noexcept, so an exception escaping a destructor calls std::terminate() instead of propagating. The ISO C++ FAQ question "How can I handle a destructor that fails?" provides more details and nuance.
    • your code example triggers -Wterminate, which GCC enables by default (I wouldn't recommend disabling it)
    • we rely heavily on try/catch statements in the script interface: even with noexcept(false), a destructor that throws while another exception is already propagating (i.e. during stack unwinding) terminates the program, so we are guaranteed to hit this somewhere
  2. There are valid reasons for not reading all input arguments: some optional arguments defined in the Python code are redundant (or irrelevant) depending on which other arguments the user provided.
  3. Such a class introduces a mutable state, which increases the complexity of the code. There are places in the script interface, like in NonBondedInteractionHandle, where we build a temporary variant map containing a subset of the original variant map. Should we copy the accessed_keys too? Or should we copy it by reference and not mark the copied values, so that we know when a value was accessed through the subset?
  4. What happens when we read one argument of the variant map, and based on the value, forward the map to the script interface parent class call_method()? The argument is now marked as having been accessed, although it wasn't necessarily used to initialize data.
  5. What should happen when we call count() or contains() to check for the presence of an argument? In the proposed wrapper, these would not add the key to accessed_keys, but if we accidentally use find() to check for the presence of an argument, the key is added. Using find() this way was common practice in ESPResSo, and while I replaced most find() calls with count() in the C++20 PR, I probably missed one or two. Should we mark an argument as accessed even though we didn't read its value? We cannot really know whether the caller is going to dereference the iterator, and writing a custom iterator class to detect dereference operations would introduce more complexity.
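Some of these trade-offs can be explored without C++. Below is a minimal Python sketch of the proposed wrapper; the class name ConsumptionTrackingDict and its unused_keys() and check_all_consumed() methods are hypothetical. It makes two design choices aimed at points 1 and 5: the check is an explicit call rather than a destructor, and only reading a value (item access or get()) marks a key as consumed, while a presence test with `in` does not.

```python
class ConsumptionTrackingDict(dict):
    """Hypothetical dict that records which keys had their value read."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._consumed = set()

    def __getitem__(self, key):
        value = super().__getitem__(key)  # raises KeyError before marking a missing key
        self._consumed.add(key)
        return value

    def get(self, key, default=None):
        if key in self:  # dict.__contains__: presence test, does not mark the key
            return self[key]  # goes through __getitem__, so the key is marked
        return default

    def unused_keys(self):
        # Keys that are present but whose value was never read.
        return set(self.keys()) - self._consumed

    def check_all_consumed(self):
        # Explicit check, to be called once all parameters have been processed.
        unused = self.unused_keys()
        if unused:
            raise ValueError(f"Unused parameters: {sorted(unused)}")


params = ConsumptionTrackingDict(kT=1.0, gamma=2.0, act_on_virtual=True)
kT = params["kT"]
gamma = params.get("gamma", 0.0)
has_seed = "seed" in params  # presence check only
print(params.unused_keys())  # {'act_on_virtual'}
try:
    params.check_all_consumed()
except ValueError as err:
    print(err)  # Unused parameters: ['act_on_virtual']
```

Note that values read through items(), values() or copy() bypass __getitem__ and are not marked, which is another instance of the ambiguity raised in points 3 and 4.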

I've thought about this issue in the past, but couldn't find satisfactory solutions:

  1. Write the list of all possible combinations of arguments in a constexpr std::vector<std::unordered_set<std::string>> and check that the variant map satisfies one of them
    • won't work with optional arguments that are assigned default values at the C++ level
    • really easy to forget to update the vector when adding or renaming arguments
    • combinatorial explosion: see how the ParticleHandle class relies on contradicting_arguments_quat to handle the case where the user provides too many valid arguments
    • doesn't scale: one has to do that for every single class method, otherwise the user will assume there was no issue with the keyword arguments if no exception is raised by a method that forgot to implement this check
  2. Use static analysis at the Python level to check the keys of **kwargs against the docstring
    • unclear how to fetch the docstring: one would have to go through the Python traceback to find the original class, which isn't a cheap operation
    • if it's a derived class, one has to also walk through the class inheritance hierarchy to find the relevant docstring, and this might fail if a docstring says See :class:`HydrodynamicInteraction` for the list of parameters
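For reference, approach 1 could look like the following Python sketch; the function name check_combination and the argument sets are made up for illustration. It also shows the first drawbacks listed above: every optional argument doubles the number of combinations that must be listed, and the list has to be kept in sync with the method by hand.

```python
# Hypothetical list of accepted keyword combinations for one method.
VALID_COMBINATIONS = [
    frozenset({"kT", "gamma"}),
    frozenset({"kT", "gamma", "seed"}),  # the optional "seed" doubles the list
]


def check_combination(kwargs):
    keys = frozenset(kwargs)
    if keys not in VALID_COMBINATIONS:
        raise ValueError(f"Invalid combination of arguments: {sorted(keys)}")


check_combination({"kT": 1.0, "gamma": 2.0})  # accepted
try:
    check_combination({"kT": 1.0, "gamma": 2.0, "act_on_virtual": True})
except ValueError as err:
    print(err)  # Invalid combination of arguments: ['act_on_virtual', 'gamma', 'kT']
```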

Can we find examples of Python projects where **kwargs are checked to see if they contain unexpected keys? In my experience, Python projects that do these checks only do it on a small fraction of all **kwargs functions. It would be helpful to see if there are projects that do so systematically, and whether they use tooling, for example to parse docstrings, or wrap the **kwargs in a special dict that implements the behavior proposed in the original post.
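One widespread idiom, shown here as a generic sketch rather than code from any particular project, is to pop each recognized key from **kwargs and raise a TypeError for whatever is left over, mirroring the error Python itself raises for an unexpected keyword argument. The function set_params() and its parameters are hypothetical:

```python
def set_params(**kwargs):
    # Consume every recognized argument, falling back to its default value.
    kT = kwargs.pop("kT", 0.0)
    gamma = kwargs.pop("gamma", 0.0)
    # Anything left over was not recognized.
    if kwargs:
        raise TypeError(
            f"set_params() got unexpected keyword arguments: {sorted(kwargs)}")
    return kT, gamma


print(set_params(kT=1.0))  # (1.0, 0.0)
try:
    set_params(kT=1.0, act_on_virtual=True)
except TypeError as err:
    print(err)  # set_params() got unexpected keyword arguments: ['act_on_virtual']
```

Because **kwargs always binds a fresh dict, popping from it does not affect the caller. The scaling problem remains: the check only exists in functions that remember to implement it.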

@jngrad
Member

jngrad commented Jun 20, 2024

Maybe we could adapt the automatic class method generation to something like:

from espressomd.script_interface import ScriptInterfaceHelper, script_interface_register


@script_interface_register
class Feature(ScriptInterfaceHelper):
    def foo(self, kT=None, gamma=None):
        kwargs = locals()
        del kwargs["self"]
        return self.call_method("foo", **kwargs)

This way, unexpected arguments always raise an exception. Maybe there is a way to express that in a more compact form with a decorator. The danger with such a solution is that locals() returns all variables inside the function. That's good for capturing derived parameters (when their calculation is more convenient in Python than in C++), but temporary variables are also picked up, which can backfire on unsuspecting developers. For example if their type is not convertible to one of the types held in our Variant, an exception is raised. We would also have to deal with None values in C++.
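A possible decorator form is sketched below under stated assumptions: the decorator name forward_to_call_method is hypothetical, the call_method(name, **kwargs) signature follows the snippet above, and FakeFeature stands in for a ScriptInterfaceHelper subclass. Binding the arguments with inspect.signature() instead of calling locals() forwards only the declared parameters, so temporaries in the method body cannot leak, and arguments left at None can be dropped before they reach C++.

```python
import functools
import inspect


def forward_to_call_method(method):
    """Hypothetical decorator: forward the declared parameters to call_method()."""
    signature = inspect.signature(method)

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # bind() raises TypeError for unexpected or missing arguments.
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        params = {name: value for name, value in bound.arguments.items()
                  if name != "self" and value is not None}
        return self.call_method(method.__name__, **params)

    return wrapper


class FakeFeature:
    """Stand-in for a ScriptInterfaceHelper subclass."""

    def call_method(self, method, **kwargs):
        return method, kwargs

    @forward_to_call_method
    def foo(self, kT=None, gamma=None):
        """Body is never executed; only the signature matters."""


feature = FakeFeature()
print(feature.foo(kT=1.0))  # ('foo', {'kT': 1.0})
try:
    feature.foo(kT=1.0, act_on_virtual=True)
except TypeError as err:
    print(err)
```

Dropping None values is one way to handle them on the Python side, at the cost that a user can no longer explicitly pass None through to C++.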

Here is an MWE to check the output of locals() in a class method context:

import unittest
global_var = 5

class P:
    def foo(self, a, b=None):
        c = a
        print(locals())

P().foo(1, b=3)

Output:

{'self': <__main__.P object at 0xa88562fafe0>, 'a': 1, 'b': 3, 'c': 1}
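A lighter alternative, assuming the parameter list is all that should be forwarded, is to take a copy of locals() as the very first statement of the method, before any temporary exists, so the snapshot contains only self and the parameters. The sketch below rewrites the MWE's foo() that way. It relies on developers keeping that line first, and in CPython a method that calls super() also sees a __class__ entry in locals().

```python
class P:
    def foo(self, a, b=None):
        # Copy the snapshot first, before any temporary variable exists.
        kwargs = dict(locals())
        del kwargs["self"]
        c = a  # temporary: created after the snapshot, so not part of kwargs
        return kwargs


print(P().foo(1, b=3))  # {'a': 1, 'b': 3}
```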

RudolfWeeber added this to the ESPResSo 4.3.0 milestone Jul 16, 2024