diff --git a/crytic_compile/utils/natspec.py b/crytic_compile/utils/natspec.py index dad68083..92a98f84 100644 --- a/crytic_compile/utils/natspec.py +++ b/crytic_compile/utils/natspec.py @@ -3,6 +3,70 @@ """ +class DevStateVariable: + """ + Model the dev state variable + """ + + def __init__(self, variable: dict) -> None: + """Init the object + + Args: + method (Dict): Method infos (details, params, returns, custom:*) + """ + self._details: str | None = variable.get("details", None) + if "returns" in variable: + self._returns: dict[str, str] = variable["returns"] + elif "return" in variable: + self._returns: dict[str, str] = {"_0": variable["return"]} + else: + self._returns: dict[str, str] = {} + # Extract custom fields (keys starting with "custom:") + self._custom: dict[str, str] = { + k: v for k, v in variable.items() if k.startswith("custom:") + } + + @property + def details(self) -> str | None: + """Return the state variable details + + Returns: + Optional[str]: state variable details + """ + return self._details + + @property + def variable_returns(self) -> dict[str, str]: + """Return the state variable returns + + Returns: + dict[str, str]: state variable returns + """ + return self._returns + + @property + def custom(self) -> dict[str, str]: + """Return the state variable custom fields + + Returns: + Dict[str, str]: custom field name => value (e.g. "custom:security" => "value") + """ + return self._custom + + def export(self) -> dict: + """Export to a python dict + + Returns: + Dict: Exported dev state variable + """ + result = { + "details": self.details, + "returns": self.variable_returns, + "custom": self.custom, + } + return result + + class UserMethod: """ Model the user method @@ -47,12 +111,17 @@ def __init__(self, method: dict) -> None: """Init the object Args: - method (Dict): Method infos (author, details, params, return, custom:*) + method (Dict): Method infos (author, details, params, returns, custom:*) """ self._author: str | None = method.get("author", None) self._details: str | None = method.get("details", None) self._params: dict[str, str] = method.get("params", {}) - self._return: str | None = method.get("return", None) + if "returns" in method: + self._returns: dict[str, str] = method["returns"] + elif "return" in method: + self._returns: dict[str, str] = {"_0": method["return"]} + else: + self._returns: dict[str, str] = {} # Extract custom fields (keys starting with "custom:") self._custom: dict[str, str] = {k: v for k, v in method.items() if k.startswith("custom:")} @@ -75,13 +144,13 @@ def details(self) -> str | None: return self._details @property - def method_return(self) -> str | None: - """Return the method return + def method_returns(self) -> dict[str, str]: + """Return the method returns Returns: - Optional[str]: method return + dict[str, str]: method returns """ - return self._return + return self._returns @property def params(self) -> dict[str, str]: @@ -111,7 +180,7 @@ def export(self) -> dict: "author": self.author, "details": self.details, "params": self.params, - "return": self.method_return, + "returns": self.method_returns, } # Include custom fields if present result.update(self.custom) @@ -180,6 +249,9 @@ def __init__(self, devdoc: dict): self._methods: dict[str, DevMethod] = { k: DevMethod(item) for k, item in devdoc.get("methods", {}).items() } + self._state_variables: dict[str, DevStateVariable] = { + k: DevStateVariable(item) for k, item in devdoc.get("stateVariables", {}).items() + } self._title: str | None = devdoc.get("title", None) # Extract contract-level custom fields (keys starting with "custom:") self._custom: dict[str, str] = {k: v for k, v in devdoc.items() if k.startswith("custom:")} @@ -211,6 +283,15 @@ def methods(self) -> dict[str, DevMethod]: """ return self._methods + @property + def state_variables(self) -> dict[str, DevStateVariable]: + """Return the dev state variables + + Returns: + Dict[str, DevStateVariable]: state_variable_name => DevStateVariable + """ + return self._state_variables + @property def title(self) -> str | None: """Return the dev title @@ -240,6 +321,7 @@ def export(self) -> dict: "author": self.author, "details": self.details, "title": self.title, + "state_variables": self.state_variables, } # Include custom fields if present result.update(self.custom) diff --git a/tests/test_natspec.py b/tests/test_natspec.py index 72afb6b0..7f977edd 100644 --- a/tests/test_natspec.py +++ b/tests/test_natspec.py @@ -2,7 +2,14 @@ Test NatSpec parsing, including custom fields (@custom:*) """ -from crytic_compile.utils.natspec import DevDoc, DevMethod, Natspec, UserDoc, UserMethod +from crytic_compile.utils.natspec import ( + DevDoc, + DevMethod, + DevStateVariable, + Natspec, + UserDoc, + UserMethod, +) class TestUserMethod: @@ -40,7 +47,7 @@ def test_devmethod_basic_fields(self) -> None: assert method.author == "Test Author" assert method.details == "Method details" assert method.params == {"a": "first param", "b": "second param"} - assert method.method_return == "return value description" + assert method.method_returns == {"_0": "return value description"} def test_devmethod_custom_fields_parsing(self) -> None: """Test DevMethod extracts custom fields""" @@ -80,7 +87,7 @@ def test_devmethod_export_includes_custom(self) -> None: assert exported["author"] == "Test Author" assert exported["details"] == "Details" assert exported["params"] == {"x": "param x"} - assert exported["return"] == "returns something" + assert exported["returns"] == {"_0": "returns something"} assert exported["custom:security"] == "critical" assert exported["custom:audit"] == "passed" @@ -90,9 +97,92 @@ def test_devmethod_empty_method(self) -> None: assert method.author is None assert method.details is None assert method.params == {} - assert method.method_return is None + assert method.method_returns == {} assert method.custom == {} + def test_devmethod_returns_dict(self) -> None: + """Test DevMethod with 'returns' dict field (multiple return values)""" + method_data = { + "details": "Method with multiple returns", + "returns": {"_0": "first value", "_1": "second value"}, + } + method = DevMethod(method_data) + assert method.method_returns == {"_0": "first value", "_1": "second value"} + + def test_devmethod_returns_takes_precedence(self) -> None: + """Test DevMethod prefers 'returns' over 'return' when both present""" + method_data = { + "returns": {"_0": "from returns"}, + "return": "from return", + } + method = DevMethod(method_data) + assert method.method_returns == {"_0": "from returns"} + + +class TestDevStateVariable: + """Tests for DevStateVariable class""" + + def test_state_variable_with_returns_dict(self) -> None: + """Test DevStateVariable with 'returns' dict field""" + var_data = { + "details": "A state variable", + "returns": {"_0": "the stored value"}, + } + var = DevStateVariable(var_data) + assert var.details == "A state variable" + assert var.variable_returns == {"_0": "the stored value"} + + def test_state_variable_with_return_string(self) -> None: + """Test DevStateVariable falls back to 'return' string field""" + var_data = { + "details": "A state variable", + "return": "the stored value", + } + var = DevStateVariable(var_data) + assert var.variable_returns == {"_0": "the stored value"} + + def test_state_variable_returns_takes_precedence(self) -> None: + """Test DevStateVariable prefers 'returns' over 'return' when both present""" + var_data = { + "returns": {"_0": "from returns"}, + "return": "from return", + } + var = DevStateVariable(var_data) + assert var.variable_returns == {"_0": "from returns"} + + def test_state_variable_empty(self) -> None: + """Test DevStateVariable with empty dict""" + var = DevStateVariable({}) + assert var.details is None + assert var.variable_returns == {} + assert var.custom == {} + + def test_state_variable_custom_fields(self) -> None: + """Test DevStateVariable extracts custom fields""" + var_data = { + "details": "A variable", + "custom:security": "sensitive", + "custom:deprecated": "true", + } + var = DevStateVariable(var_data) + assert var.custom == { + "custom:security": "sensitive", + "custom:deprecated": "true", + } + + def test_state_variable_export(self) -> None: + """Test DevStateVariable export""" + var_data = { + "details": "A state variable", + "returns": {"_0": "the value"}, + "custom:audit": "verified", + } + var = DevStateVariable(var_data) + exported = var.export() + assert exported["details"] == "A state variable" + assert exported["returns"] == {"_0": "the value"} + assert exported["custom"] == {"custom:audit": "verified"} + class TestUserDoc: """Tests for UserDoc class"""