Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,41 @@ impl Engine {
v.to_json_str()
}

/// Registers a custom Python function as a Rego extension.
///
/// This allows you to define functions in Python that can be called directly
/// from your Rego policies. The Python function will be called synchronously
/// during policy evaluation.
///
/// Arguments passed from Rego are automatically converted to their corresponding
/// Python types. The return value is converted back to a Rego value.
///
/// * `path`: Full path to the function as it will be used in Rego.
/// * `nargs`: The number of arguments the function expects.
/// * `extension`: The Python function to execute. Must accept exactly `nargs` arguments.
pub fn add_extension(&mut self, path: String, nargs: u8, extension: Py<PyAny>) -> Result<()> {
let func_ref = Arc::new(extension);
let path_clone = path.clone();

let extension_impl = move |args: Vec<Value>| -> Result<Value, anyhow::Error> {
Python::with_gil(|py| {
if !func_ref.bind(py).is_callable() {
return Err(anyhow!("extension must be callable"))
}
let py_args_vec: Result<Vec<PyObject>> =
args.into_iter().map(|arg| to(arg, py)).collect();
let py_args = PyTuple::new(py, py_args_vec?)?;
let py_result = func_ref.call1(py, py_args)
.map_err(|e| anyhow!("extension '{}' raises Python error: {}", path_clone, e))?;
let rego_result = from(&py_result.into_bound(py))?;
Ok(rego_result)
})
};

self.engine
.add_extension(path, nargs, Box::new(extension_impl))
}

/// Enable code coverage
///
/// * `enable`: Whether to enable coverage or not.
Expand Down
207 changes: 207 additions & 0 deletions bindings/python/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import typing
import pytest

import regorus
import sys
Expand Down Expand Up @@ -163,3 +165,208 @@ def run_host_await_example():
print(vm.resume('{"tier":"gold"}'))

run_host_await_example()

def test_extension_execution():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting(a, b) if {
a := data.a
b := data.b
}
""")

def custom_function(arg1, arg2):
return f"{arg1}, {arg2}!"
rego.add_extension("greeting", 2, custom_function)

rego.add_data({"a": "Hello", "b": "World"})
result = rego.eval_rule("data.demo.result")
assert result == "Hello, World!", f"Unexpected result: {result}"

def test_extension_wrong_arity():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting(a, b) if {
a := data.a
b := data.b
}
""")

def custom_function(arg1, arg2):
return f"{arg1}, {arg2}!"

rego.add_extension("greeting", 3, custom_function)
rego.add_data({"a": "Hello", "b": "World"})

with pytest.raises(RuntimeError) as ex:
rego.eval_rule("data.demo.result")

assert "error: incorrect number of parameters supplied to extension" in str(ex.value)

def test_extension_raises_exception():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting(a, b) if {
a := data.a
b := data.b
}
""")

def custom_function(arg1, arg2):
raise RuntimeError("unknown error")

rego.add_extension("greeting", 2, custom_function)
rego.add_data({"a": "Hello", "b": "World"})

with pytest.raises(RuntimeError) as ex:
rego.eval_rule("data.demo.result")

assert "error: extension 'greeting' raise Python error: RuntimeError: unknown error" in str(ex.value)

def test_extension_zero_arg():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting()
""")

def custom_function():
return "Hello, World!"

rego.add_extension("greeting", 0, custom_function)
rego.add_data({"a": "Hello", "b": "World"})

result = rego.eval_rule("data.demo.result")
assert result == "Hello, World!", f"Unexpected result: {result}"

def test_extension_non_callable():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting()
""")

rego.add_extension("greeting", 0, 123)
rego.add_data({"a": "Hello", "b": "World"})

with pytest.raises(RuntimeError) as ex:
rego.eval_rule("data.demo.result")

assert "error: extension must be callable" in str(ex.value)


def test_extension_duplicate():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting()
""")

def custom_function1(arg1, arg2):
return f"{arg1}, {arg2}!"
def custom_function2(arg1, arg2):
return f"{arg1}, {arg2}!"

rego.add_extension("greeting", 0, custom_function1)

with pytest.raises(RuntimeError) as ex:
rego.add_extension("greeting", 0, custom_function2)

assert "extension already added" in str(ex.value)


def test_extension_types():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

i := custom.triple(10)
f := custom.triple(2.5)
b1 := custom.negate(true)
b2 := custom.negate(false)

a := custom.first([true, null, 1])
b := custom.first([null, null, 1])
c := custom.first([null, null, null])

object := custom.modify_object({"a": 1, "b": 2})
list := custom.modify_list([3, 4])
set := custom.modify_set({5, 6})
""")

def triple(n):
return n*3

def negate(b):
return not b

def first(list):
for i in list:
if i is not None:
return i
return None

def modify_object(object):
assert isinstance(object, dict)
return {k: v*2 for k, v in object.items()}

def modify_list(list):
assert isinstance(list, typing.List)
return [x*2 for x in list]

def modify_set(set):
assert isinstance(set, typing.Set)
return {x*2 for x in set}

rego.add_extension("custom.triple", 1, triple)
rego.add_extension("custom.negate", 1, negate)
rego.add_extension("custom.first", 1, first)
rego.add_extension("custom.modify_object", 1, modify_object)
rego.add_extension("custom.modify_list", 1, modify_list)
rego.add_extension("custom.modify_set", 1, modify_set)

i = rego.eval_rule("data.demo.i")
assert i == 30, f"Unexpected result for 'i': {i}"

f = rego.eval_rule("data.demo.f")
assert f == 7.5, f"Unexpected result for 'f': {f}"

b1 = rego.eval_rule("data.demo.b1")
assert b1 == False, f"Unexpected result for 'b1': {b1}"

b2 = rego.eval_rule("data.demo.b2")
assert b2 == True, f"Unexpected result for 'b2': {b2}"

a = rego.eval_rule("data.demo.a")
assert a == True, f"Unexpected result for 'a': {a}"

b = rego.eval_rule("data.demo.b")
assert b == 1, f"Unexpected result for 'b': {b}"

c = rego.eval_rule("data.demo.c")
assert c is None, f"Unexpected result for 'c': {c}"

obj = rego.eval_rule("data.demo.object")
assert obj == {"a": 2, "b": 4}, f"Unexpected object: {obj}"

list = rego.eval_rule("data.demo.list")
assert list == [6, 8], f"Unexpected list: {list}"

set = rego.eval_rule("data.demo.set")
assert set == {10, 12}, f"Unexpected list: {set}"