diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f8d11bc3..1c110b42d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Support for the DeliverMax field in Payment transactions - Support for the `feature` RPC +- Improved validation for models to also check param types ### Fixed - Allow empty strings for the purpose of removing fields in DIDSet transaction diff --git a/tests/unit/models/test_base_model.py b/tests/unit/models/test_base_model.py index 03c628c97..b22695207 100644 --- a/tests/unit/models/test_base_model.py +++ b/tests/unit/models/test_base_model.py @@ -28,6 +28,7 @@ SignerListSet, TrustSet, TrustSetFlag, + XChainAddAccountCreateAttestation, XChainClaim, ) from xrpl.models.transactions.transaction import Transaction @@ -83,15 +84,44 @@ def test_is_dict_of_model_when_not_true(self): ), ) + def test_bad_type(self): + transaction_dict = { + "account": 1, + "amount": 10, + "destination": 1, + } + with self.assertRaises(XRPLModelException): + Payment(**transaction_dict) + + def test_bad_type_flags(self): + transaction_dict = { + "account": account, + "amount": value, + "destination": destination, + "flags": "1234", # should be an int + } + with self.assertRaises(XRPLModelException): + Payment(**transaction_dict) + + def test_bad_type_enum(self): + path_find_dict = { + "subcommand": "blah", # this is invalid + "source_account": "raoV5dkC66XvGWjSzUhCUuuGM3YFTitMxT", + "destination_account": "rJjusz1VauNA9XaHxJoiwHe38bmQFz1sUV", + "destination_amount": "100", + } + with self.assertRaises(XRPLModelException): + PathFind(**path_find_dict) + class TestFromDict(TestCase): maxDiff = 2000 - def test_from_dict_basic(self): + def test_basic(self): amount = IssuedCurrencyAmount.from_dict(amount_dict) self.assertEqual(amount, IssuedCurrencyAmount(**amount_dict)) - def test_from_dict_recursive_amount(self): + def test_recursive_amount(self): check_create = CheckCreate.from_dict(check_create_dict) expected_dict = { @@ -102,7 +132,7 @@ def test_from_dict_recursive_amount(self): } self.assertEqual(expected_dict, check_create.to_dict()) - def test_from_dict_recursive_currency(self): + def test_recursive_currency(self): xrp = {"currency": "XRP"} issued_currency = { "currency": currency, @@ -122,7 +152,7 @@ def test_from_dict_recursive_currency(self): } self.assertEqual(expected_dict, book_offers.to_dict()) - def test_from_dict_recursive_transaction(self): + def test_recursive_transaction(self): transaction = CheckCreate.from_dict(check_create_dict) sign_dict = {"secret": secret, "transaction": transaction.to_dict()} sign = Sign.from_dict(sign_dict) @@ -139,7 +169,7 @@ def test_from_dict_recursive_transaction(self): del expected_dict["transaction"] self.assertEqual(expected_dict, sign.to_dict()) - def test_from_dict_recursive_transaction_tx_json(self): + def test_recursive_transaction_tx_json(self): transaction = CheckCreate.from_dict(check_create_dict) sign_dict = {"secret": secret, "tx_json": transaction.to_dict()} sign = Sign.from_dict(sign_dict) @@ -155,7 +185,7 @@ def test_from_dict_recursive_transaction_tx_json(self): } self.assertEqual(expected_dict, sign.to_dict()) - def test_from_dict_signer(self): + def test_signer(self): dictionary = { "account": "rpqBNcDpWaqZC2Rksayf8UyG66Fyv2JTQy", "fee": "10", @@ -186,7 +216,7 @@ def test_from_dict_signer(self): actual = SignerListSet.from_dict(dictionary) self.assertEqual(actual, expected) - def test_from_dict_trust_set(self): + def test_trust_set(self): dictionary = { "account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr", "fee": "10", @@ -210,7 +240,7 @@ def test_from_dict_trust_set(self): actual = TrustSet.from_dict(dictionary) self.assertEqual(actual, expected) - def test_from_dict_list_of_lists(self): + def test_list_of_lists(self): path_step_dict = {"account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr"} path_find_dict = { "subcommand": PathFindSubcommand.CREATE, @@ -230,7 +260,7 @@ def test_from_dict_list_of_lists(self): actual = PathFind.from_dict(path_find_dict) self.assertEqual(actual, expected) - def test_from_dict_any(self): + def test_any(self): account_channels_dict = { "account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr", "marker": "something", @@ -239,7 +269,7 @@ def test_from_dict_any(self): actual = AccountChannels.from_dict(account_channels_dict) self.assertEqual(actual, expected) - def test_from_dict_bad_str(self): + def test_bad_str(self): dictionary = { "account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr", "fee": 10, # this should be a str instead ("10") @@ -254,7 +284,7 @@ def test_from_dict_bad_str(self): with self.assertRaises(XRPLModelException): TrustSet.from_dict(dictionary) - def test_from_dict_explicit_none(self): + def test_explicit_none(self): dictionary = { "account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr", "fee": "10", @@ -277,7 +307,7 @@ def test_from_dict_explicit_none(self): actual = TrustSet.from_dict(dictionary) self.assertEqual(actual, expected) - def test_from_dict_with_str_enum_value(self): + def test_with_str_enum_value(self): dictionary = { "method": "account_channels", "account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr", @@ -290,7 +320,7 @@ def test_from_dict_with_str_enum_value(self): actual = AccountChannels.from_dict(dictionary) self.assertEqual(actual, expected) - def test_from_dict_bad_list(self): + def test_bad_list(self): dictionary = { "account": "rpqBNcDpWaqZC2Rksayf8UyG66Fyv2JTQy", "fee": "10", @@ -307,7 +337,7 @@ def test_from_dict_bad_list(self): with self.assertRaises(XRPLModelException): SignerListSet.from_dict(dictionary) - def test_from_dict_multisign(self): + def test_multisign(self): txn_sig1 = ( "F80E201FE295AA08678F8542D8FC18EA18D582A0BD19BE77B9A24479418ADBCF4CAD28E7BD" "96137F88DE7736827C7AC6204FBA8DDADB7394E6D704CD1F4CD609" @@ -379,7 +409,7 @@ def test_from_dict_multisign(self): actual = Request.from_dict(request) self.assertEqual(actual, expected) - def test_from_dict_submit(self): + def test_submit(self): blob = "SOISUSF9SD0839W8U98J98SF" id_val = "submit_786514" request = { @@ -392,49 +422,83 @@ def test_from_dict_submit(self): actual = Request.from_dict(request) self.assertEqual(actual, expected) - # Note: BaseModel.from_xrpl and its overridden methods accept only camelCase or - # PascalCase inputs (i.e. snake_case is not accepted) - def test_request_input_from_xrpl_accepts_camel_case(self): - request = { - "method": "submit", - "tx_json": { - "Account": "rnD6t3JF9RTG4VgNLoc4i44bsQLgJUSi6h", - "transaction_type": "TrustSet", - "Fee": "10", - "Sequence": 17896798, - "Flags": 131072, - "signing_pub_key": "", - "limit_amount": { - "currency": "USD", - "issuer": "rH5gvkKxGHrFAMAACeu9CB3FMu7pQ9jfZm", - "value": "10", - }, + def test_nonexistent_field(self): + tx = { + "account": "rH6ZiHU1PGamME2LvVTxrgvfjQpppWKGmr", + "bad_field": "random", + "flags": 131072, + "limit_amount": { + "currency": "USD", + "issuer": "raoV5dkC66XvGWjSzUhCUuuGM3YFTitMxT", + "value": "100", }, - "fail_hard": False, } + with self.assertRaises(XRPLModelException): + TrustSet.from_dict(tx) + def test_bad_literal(self): + tx = { + "account": issuer, + "xchain_bridge": { + "locking_chain_door": issuer, + "locking_chain_issue": {"currency": "XRP"}, + "issuing_chain_door": issuer, + "issuing_chain_issue": {"currency": "XRP"}, + }, + "public_key": "0342E083EA762D91D621714C394", + "signature": "3044022053B26DAAC9C886192C95", + "other_chain_source": issuer, + "amount": amount_dict, + "attestation_reward_account": issuer, + "attestation_signer_account": issuer, + "was_locking_chain_send": 2, # supposed to be 0 or 1 + "xchain_account_create_count": 12, + "destination": issuer, + "signature_reward": "200", + } with self.assertRaises(XRPLModelException): - Request.from_xrpl(request) + XChainAddAccountCreateAttestation.from_dict(tx) - def test_transaction_input_from_xrpl_accepts_only_camel_case(self): - # verify that Transaction.from_xrpl method does not accept snake_case JSON keys - tx_snake_case_keys = { - "Account": "rnoGkgSpt6AX1nQxZ2qVGx7Fgw6JEcoQas", - "transaction_type": "TrustSet", - "Fee": "10", - "Sequence": 17892983, - "Flags": 131072, - "signing_pub_key": "", - "limit_amount": { - "currency": "USD", - "issuer": "rBPvTKisx7UCGLDtiUZ6mDssXNREuVuL8Y", - "value": "10", + def test_good_literal(self): + tx = { + "account": issuer, + "xchain_bridge": { + "locking_chain_door": issuer, + "locking_chain_issue": {"currency": "XRP"}, + "issuing_chain_door": issuer, + "issuing_chain_issue": {"currency": "XRP"}, }, + "public_key": "0342E083EA762D91D621714C394", + "signature": "3044022053B26DAAC9C886192C95", + "other_chain_source": issuer, + "amount": "100", + "attestation_reward_account": issuer, + "attestation_signer_account": issuer, + "was_locking_chain_send": 1, + "xchain_account_create_count": 12, + "destination": issuer, + "signature_reward": "200", } + expected_dict = { + **tx, + "xchain_bridge": XChainBridge.from_dict(tx["xchain_bridge"]), + } + expected = XChainAddAccountCreateAttestation( + **expected_dict, + ) + self.assertEqual(XChainAddAccountCreateAttestation.from_dict(tx), expected) + + def test_enum(self): + path_find_dict = { + "subcommand": "create", + "source_account": "raoV5dkC66XvGWjSzUhCUuuGM3YFTitMxT", + "destination_account": "rJjusz1VauNA9XaHxJoiwHe38bmQFz1sUV", + "destination_amount": "100", + } + self.assertEqual(PathFind.from_dict(path_find_dict), PathFind(**path_find_dict)) - with self.assertRaises(XRPLModelException): - Transaction.from_xrpl(tx_snake_case_keys) +class TestFromXrpl(TestCase): def test_from_xrpl(self): dirname = os.path.dirname(__file__) full_filename = "x-codec-fixtures.json" @@ -444,14 +508,18 @@ def test_from_xrpl(self): for test in fixtures_json["transactions"]: x_json = test["xjson"] r_json = test["rjson"] - with self.subTest(json=x_json): + with self.subTest(json=x_json, use_json=False): tx = Transaction.from_xrpl(x_json) translated_tx = tx.to_xrpl() self.assertEqual(x_json, translated_tx) - with self.subTest(json=r_json): + with self.subTest(json=r_json, use_json=False): tx = Transaction.from_xrpl(r_json) translated_tx = tx.to_xrpl() self.assertEqual(r_json, translated_tx) + with self.subTest(json=r_json, use_json=True): + tx = Transaction.from_xrpl(json.dumps(r_json)) + translated_tx = tx.to_xrpl() + self.assertEqual(r_json, translated_tx) def test_from_xrpl_signers(self): txn_sig1 = ( @@ -749,3 +817,46 @@ def test_to_from_xrpl_xchain(self): ) self.assertEqual(tx_obj.to_xrpl(), tx_json) self.assertEqual(Transaction.from_xrpl(tx_json), tx_obj) + + def test_request_input_from_xrpl_accepts_camel_case(self): + # Note: BaseModel.from_xrpl and its overridden methods accept only camelCase or + # PascalCase inputs (i.e. snake_case is not accepted) + request = { + "method": "submit", + "tx_json": { + "Account": "rnD6t3JF9RTG4VgNLoc4i44bsQLgJUSi6h", + "transaction_type": "TrustSet", + "Fee": "10", + "Sequence": 17896798, + "Flags": 131072, + "signing_pub_key": "", + "limit_amount": { + "currency": "USD", + "issuer": "rH5gvkKxGHrFAMAACeu9CB3FMu7pQ9jfZm", + "value": "10", + }, + }, + "fail_hard": False, + } + + with self.assertRaises(XRPLModelException): + Request.from_xrpl(request) + + def test_transaction_input_from_xrpl_accepts_only_camel_case(self): + # verify that Transaction.from_xrpl method does not accept snake_case JSON keys + tx_snake_case_keys = { + "Account": "rnoGkgSpt6AX1nQxZ2qVGx7Fgw6JEcoQas", + "transaction_type": "TrustSet", + "Fee": "10", + "Sequence": 17892983, + "Flags": 131072, + "signing_pub_key": "", + "limit_amount": { + "currency": "USD", + "issuer": "rBPvTKisx7UCGLDtiUZ6mDssXNREuVuL8Y", + "value": "10", + }, + } + + with self.assertRaises(XRPLModelException): + Transaction.from_xrpl(tx_snake_case_keys) diff --git a/tests/unit/models/transactions/test_check_cash.py b/tests/unit/models/transactions/test_check_cash.py index 4fadb8537..aa3fca77d 100644 --- a/tests/unit/models/transactions/test_check_cash.py +++ b/tests/unit/models/transactions/test_check_cash.py @@ -6,7 +6,7 @@ _ACCOUNT = "r9LqNeG6qHxjeUocjvVki2XR35weJ9mZgQ" _FEE = "0.00001" _SEQUENCE = 19048 -_CHECK_ID = 19048 +_CHECK_ID = "838766BA2B995C00744175F69A1B11E32C3DBC40E64801A4056FCBD657F57334" _AMOUNT = "300" diff --git a/tests/unit/models/transactions/test_oracle_set.py b/tests/unit/models/transactions/test_oracle_set.py index efd106d41..a694fd1e8 100644 --- a/tests/unit/models/transactions/test_oracle_set.py +++ b/tests/unit/models/transactions/test_oracle_set.py @@ -326,7 +326,7 @@ def test_early_last_update_time_field(self): self.assertEqual( err.exception.args[0], "{'last_update_time': 'LastUpdateTime" - + " must be greater than or equal to Ripple-Epoch 946684800.0 seconds'}", + + " must be greater than or equal to ripple epoch - 946684800 seconds'}", ) # Validity depends on the time of the Last Closed Ledger. This test verifies the diff --git a/tests/unit/models/transactions/test_xchain_claim.py b/tests/unit/models/transactions/test_xchain_claim.py index 82f9e690d..eb3d9efef 100644 --- a/tests/unit/models/transactions/test_xchain_claim.py +++ b/tests/unit/models/transactions/test_xchain_claim.py @@ -68,7 +68,7 @@ def test_successful_claim_destination_tag(self): xchain_bridge=_XRP_BRIDGE, xchain_claim_id=_CLAIM_ID, destination=_DESTINATION, - destination_tag="12345", + destination_tag=12345, amount=_XRP_AMOUNT, ) diff --git a/xrpl/models/base_model.py b/xrpl/models/base_model.py index 84c0416fc..19ae1f911 100644 --- a/xrpl/models/base_model.py +++ b/xrpl/models/base_model.py @@ -291,6 +291,67 @@ def is_valid(self: Self) -> bool: """ return len(self._get_errors()) == 0 + def _check_type( + self: Self, attr: str, value: Any, expected_type: Type[Any] + ) -> Dict[str, str]: + """ + Returns error dictionary if the type of `value` does not match the + `expected_type`. + """ + expected_type_origin = get_origin(expected_type) + if expected_type_origin is Union: + if any( + len(self._check_type(attr, value, expected_type_option)) == 0 + for expected_type_option in get_args(expected_type) + ): + return {} + return {attr: f"{attr} is {type(value)}, expected {expected_type}"} + + # unsure what the problem with mypy is here + if expected_type is Any: # type: ignore[comparison-overlap] + return {} + + if expected_type_origin is list: + # expected a List, received a List + if not isinstance(value, list): + return {attr: f"{attr} is {type(value)}, expected {expected_type}"} + result = {} + for i in range(len(value)): + result.update( + self._check_type( + f"{attr}[{i}]", value[i], get_args(expected_type)[0] + ) + ) + return result + + if expected_type_origin is dict: + return ( + {} + if isinstance(value, dict) + else {attr: f"{attr} is {type(value)}, expected {expected_type}"} + ) + + if isinstance(expected_type, type) and issubclass(expected_type, Enum): + return ( + {} + if value in list(expected_type) + else { + attr: f"{attr} is {value}, expected member of {expected_type} enum" + } + ) + + if expected_type_origin is Literal: + arg = get_args(expected_type) + return {} if value in arg else {attr: f"{attr} is {value}, expected {arg}"} + + if issubclass(expected_type, BaseModel) and isinstance(value, dict): + return {} + + if not isinstance(value, expected_type): + return {attr: f"{attr} is {type(value)}, expected {expected_type}"} + + return {} + def _get_errors(self: Self) -> Dict[str, str]: """ Extended in subclasses to define custom validation logic. @@ -298,11 +359,14 @@ def _get_errors(self: Self) -> Dict[str, str]: Returns: Dictionary of any errors found on self. """ - return { - attr: f"{attr} is not set" - for attr, value in self.__dict__.items() - if value is REQUIRED - } + class_types = get_type_hints(self.__class__) + result: Dict[str, str] = {} + for attr, value in self.__dict__.items(): + if value is REQUIRED: + result[attr] = f"{attr} is not set" + else: + result.update(self._check_type(attr, value, class_types[attr])) + return result def to_dict(self: Self) -> Dict[str, Any]: """ @@ -339,6 +403,6 @@ def __eq__(self: Self, other: object) -> bool: return isinstance(other, BaseModel) and self.to_dict() == other.to_dict() def __repr__(self: Self) -> str: - """Returns a string representation of a BaseModel object""" + """Returns a string representation of a BaseModel object.""" repr_items = [f"{key}={repr(value)}" for key, value in self.to_dict().items()] return f"{type(self).__name__}({repr_items})" diff --git a/xrpl/models/requests/ledger_entry.py b/xrpl/models/requests/ledger_entry.py index 974c7abbf..ca4badbd2 100644 --- a/xrpl/models/requests/ledger_entry.py +++ b/xrpl/models/requests/ledger_entry.py @@ -256,9 +256,9 @@ class LedgerEntry(Request, LookupByLedgerRequest): ticket: Optional[Union[str, Ticket]] = None bridge_account: Optional[str] = None bridge: Optional[XChainBridge] = None - xchain_claim_id: Optional[Union[str, XChainClaimID]] = None + xchain_claim_id: Optional[Union[int, str, XChainClaimID]] = None xchain_create_account_claim_id: Optional[ - Union[str, XChainCreateAccountClaimID] + Union[int, str, XChainCreateAccountClaimID] ] = None binary: bool = False diff --git a/xrpl/models/transactions/oracle_set.py b/xrpl/models/transactions/oracle_set.py index 17b7f0147..f2840ff13 100644 --- a/xrpl/models/transactions/oracle_set.py +++ b/xrpl/models/transactions/oracle_set.py @@ -20,10 +20,10 @@ MAX_ORACLE_SYMBOL_CLASS = 16 # epoch offset must equal 946684800 seconds. It represents the diff between the -# genesis of Unix time and Ripple-Epoch time -EPOCH_OFFSET = ( - datetime.datetime(2000, 1, 1) - datetime.datetime(1970, 1, 1) -).total_seconds() +# genesis of Unix time and ripple epoch time +EPOCH_OFFSET = int( + (datetime.datetime(2000, 1, 1) - datetime.datetime(1970, 1, 1)).total_seconds() +) @require_kwargs_on_init @@ -146,7 +146,7 @@ def _get_errors(self: Self) -> Dict[str, str]: if self.last_update_time < EPOCH_OFFSET: errors["last_update_time"] = ( "LastUpdateTime must be greater than or equal" - f" to Ripple-Epoch {EPOCH_OFFSET} seconds" + f" to ripple epoch - {EPOCH_OFFSET} seconds" ) return errors diff --git a/xrpl/models/transactions/pseudo_transactions/enable_amendment.py b/xrpl/models/transactions/pseudo_transactions/enable_amendment.py index 180ed56de..5b03cf659 100644 --- a/xrpl/models/transactions/pseudo_transactions/enable_amendment.py +++ b/xrpl/models/transactions/pseudo_transactions/enable_amendment.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from xrpl.models.flags import FlagInterface from xrpl.models.required import REQUIRED @@ -101,7 +101,7 @@ class EnableAmendment(PseudoTransaction): init=False, ) - flags: Union[int, List[int]] = 0 + flags: Union[Dict[str, bool], int, List[int]] = 0 """ The Flags value of the EnableAmendment pseudo-transaction indicates the status of the amendment at the time of the ledger including the pseudo-transaction. diff --git a/xrpl/models/transactions/transaction.py b/xrpl/models/transactions/transaction.py index ea25aab8c..c582222e5 100644 --- a/xrpl/models/transactions/transaction.py +++ b/xrpl/models/transactions/transaction.py @@ -251,9 +251,6 @@ class Transaction(BaseModel): """The network id of the transaction.""" def _get_errors(self: Self) -> Dict[str, str]: - # import must be here to avoid circular dependencies - from xrpl.wallet.main import Wallet - errors = super()._get_errors() if self.ticket_sequence is not None and ( (self.sequence is not None and self.sequence != 0) @@ -264,9 +261,6 @@ def _get_errors(self: Self) -> Dict[str, str]: ] = """If ticket_sequence is provided, account_txn_id must be None and sequence must be None or 0""" - if isinstance(self.account, Wallet): - errors["account"] = "Must pass in `wallet.address`, not `wallet`." - return errors def to_dict(self: Self) -> Dict[str, Any]: @@ -371,6 +365,9 @@ def has_flag(self: Self, flag: int) -> bool: Returns: Whether the transaction has the given flag value set. + + Raises: + XRPLModelException: if `self.flags` is invalid. """ if isinstance(self.flags, int): return self.flags & flag != 0 @@ -379,8 +376,10 @@ def has_flag(self: Self, flag: int) -> bool: tx_type=self.transaction_type, tx_flags=self.flags, ) - else: # is List[int] + elif isinstance(self.flags, list): return flag in self.flags + else: + raise XRPLModelException("self.flags is not an int, dict, or list") def is_signed(self: Self) -> bool: """ diff --git a/xrpl/utils/xrp_conversions.py b/xrpl/utils/xrp_conversions.py index d8aa2cf3c..5eaaa4829 100644 --- a/xrpl/utils/xrp_conversions.py +++ b/xrpl/utils/xrp_conversions.py @@ -1,4 +1,5 @@ """Conversions between XRP drops and native number types.""" + import re from decimal import Decimal, InvalidOperation, localcontext from typing import Pattern, Union @@ -35,7 +36,7 @@ def xrp_to_drops(xrp: Union[int, float, Decimal]) -> str: TypeError: if ``xrp`` is given as a string XRPRangeException: if the given amount of XRP is invalid """ - if type(xrp) == str: # type: ignore + if isinstance(xrp, str): # This protects people from passing drops to this function and getting # a million times as many drops back. raise TypeError( @@ -83,7 +84,7 @@ def drops_to_xrp(drops: str) -> Decimal: TypeError: if ``drops`` not given as a string XRPRangeException: if the given number of drops is invalid """ - if type(drops) != str: + if not isinstance(drops, str): raise TypeError(f"Drops must be provided as string (got {type(drops)})") drops = drops.strip() with localcontext(DROPS_DECIMAL_CONTEXT):