diff --git a/hathor/conf/settings.py b/hathor/conf/settings.py index 6a0ca8e47..aab40cc1b 100644 --- a/hathor/conf/settings.py +++ b/hathor/conf/settings.py @@ -19,7 +19,7 @@ from hathor.checkpoint import Checkpoint from hathor.consensus.consensus_settings import ConsensusSettings, PowSettings from hathor.feature_activation.settings import Settings as FeatureActivationSettings -from hathorlib.conf.settings import HathorSettings as LibSettings +from hathorlib.conf.settings import FeatureSetting, HathorSettings as LibSettings DECIMAL_PLACES = 2 @@ -32,6 +32,13 @@ class HathorSettings(LibSettings): model_config = ConfigDict(extra='forbid') + # Fee rate settings for shielded outputs + FEE_PER_AMOUNT_SHIELDED_OUTPUT: int = 1 + FEE_PER_FULL_SHIELDED_OUTPUT: int = 2 + + # Used to enable shielded transactions. + ENABLE_SHIELDED_TRANSACTIONS: FeatureSetting = FeatureSetting.DISABLED + # Block checkpoints CHECKPOINTS: list[Checkpoint] = [] diff --git a/hathor/consensus/consensus.py b/hathor/consensus/consensus.py index 44c76d141..b022aa282 100644 --- a/hathor/consensus/consensus.py +++ b/hathor/consensus/consensus.py @@ -25,6 +25,7 @@ from hathor.consensus.transaction_consensus import TransactionConsensusAlgorithmFactory from hathor.execution_manager import non_critical_code from hathor.feature_activation.feature import Feature +from hathor.feature_activation.utils import Features from hathor.nanocontracts.exception import NCInvalidSignature from hathor.nanocontracts.execution import NCBlockExecutor, NCConsensusBlockExecutor from hathor.profiler import get_cpu_profiler @@ -456,6 +457,9 @@ def _feature_activation_rules(self, tx: Transaction, new_best_block: Block) -> b case Feature.OPCODES_V2: if not self._opcodes_v2_activation_rule(tx, new_best_block): return False + case Feature.SHIELDED_TRANSACTIONS: + if not self._shielded_activation_rule(tx, is_active): + return False case ( Feature.INCREASE_MAX_MERKLE_PATH_LENGTH | Feature.NOP_FEATURE_1 @@ -506,6 +510,16 @@ def 
_fee_tokens_activation_rule(self, tx: Transaction, is_active: bool) -> bool: return True + def _shielded_activation_rule(self, tx: Transaction, is_active: bool) -> bool: + """Check whether a tx became invalid because the reorg changed the shielded feature activation state.""" + if is_active: + return True + + if tx.has_shielded_outputs(): + return False + + return True + def _checkdatasig_count_rule(self, tx: Transaction) -> bool: """Check whether a tx became invalid because of the count checkdatasig feature.""" from hathor.verification.vertex_verifier import VertexVerifier @@ -530,7 +544,12 @@ def _opcodes_v2_activation_rule(self, tx: Transaction, new_best_block: Block) -> # We check all txs regardless of the feature state, because this rule # already prohibited mempool txs before the block feature activation. - params = VerificationParams.default_for_mempool(best_block=new_best_block) + features = Features.from_vertex( + settings=self._settings, + feature_service=self.feature_service, + vertex=new_best_block, + ) + params = VerificationParams.default_for_mempool(best_block=new_best_block, features=features) # Any exception in the inputs verification will be considered # a fail and the tx will be removed from the mempool. 
diff --git a/hathor/dag_builder/default_filler.py b/hathor/dag_builder/default_filler.py index b104d0a7c..88dc68c5e 100644 --- a/hathor/dag_builder/default_filler.py +++ b/hathor/dag_builder/default_filler.py @@ -128,6 +128,20 @@ def calculate_balance(self, node: DAGNode) -> dict[str, int]: return balance + def _account_for_shielded_fee(self, node: DAGNode) -> None: + """Subtract shielded output fees from the node's HTR balance.""" + fee = 0 + for txout in node.outputs: + if txout is None: + continue + _, _, attrs = txout + if attrs.get('full-shielded'): + fee += self._settings.FEE_PER_FULL_SHIELDED_OUTPUT + elif attrs.get('shielded'): + fee += self._settings.FEE_PER_AMOUNT_SHIELDED_OUTPUT + if fee > 0: + node.balances['HTR'] = node.balances.get('HTR', 0) - fee + def balance_node_inputs_and_outputs(self, node: DAGNode) -> None: """Balance the inputs and outputs of a node.""" balance = self.calculate_balance(node) @@ -222,6 +236,7 @@ def run(self) -> None: continue self.fill_parents(node) + self._account_for_shielded_fee(node) self.balance_node_inputs_and_outputs(node) case DAGNodeType.OnChainBlueprint: diff --git a/hathor/dag_builder/tokenizer.py b/hathor/dag_builder/tokenizer.py index 19dbbed55..7a4277d70 100644 --- a/hathor/dag_builder/tokenizer.py +++ b/hathor/dag_builder/tokenizer.py @@ -214,7 +214,13 @@ def tokenize(content: str) -> Iterator[Token]: index = int(key[4:-1]) amount = int(parts[2]) token = parts[3] - attrs = parts[4:] + raw_attrs = parts[4:] + attrs: dict[str, str | int] = {} + for a in raw_attrs: + if a.startswith('[') and a.endswith(']'): + attrs[a[1:-1]] = 1 + else: + attrs[a] = 1 yield (TokenType.OUTPUT, (name, index, amount, token, attrs)) else: value = ' '.join(parts[2:]) diff --git a/hathor/dag_builder/vertex_exporter.py b/hathor/dag_builder/vertex_exporter.py index fb26f93fd..cc1fb7d4d 100644 --- a/hathor/dag_builder/vertex_exporter.py +++ b/hathor/dag_builder/vertex_exporter.py @@ -142,13 +142,18 @@ def _create_vertex_txout( *, 
token_creation: bool = False ) -> tuple[list[bytes], list[TxOutput]]: - """Create TxOutput objects for a node.""" + """Create TxOutput objects for a node. Shielded outputs are skipped here.""" tokens: list[bytes] = [] outputs: list[TxOutput] = [] for txout in node.outputs: assert txout is not None amount, token_name, attrs = txout + + # Skip shielded outputs — they are handled by add_shielded_outputs_header_if_needed + if attrs.get('shielded') or attrs.get('full-shielded'): + continue + if token_name == 'HTR': index = 0 elif token_creation: @@ -330,6 +335,8 @@ def add_headers_if_needed(self, node: DAGNode, vertex: BaseTransaction) -> None: """Add the configured headers.""" self.add_nano_header_if_needed(node, vertex) self.add_fee_header_if_needed(node, vertex) + self.add_shielded_outputs_header_if_needed(node, vertex) + self._add_or_augment_shielded_fee(node, vertex) def add_nano_header_if_needed(self, node: DAGNode, vertex: BaseTransaction) -> None: if 'nc_id' not in node.attrs: @@ -456,6 +463,65 @@ def add_fee_header_if_needed(self, node: DAGNode, vertex: BaseTransaction) -> No ) vertex.headers.append(fee_header) + def _add_or_augment_shielded_fee(self, node: DAGNode, vertex: BaseTransaction) -> None: + """Add or augment a FeeHeader with the shielded output fee.""" + if not isinstance(vertex, Transaction): + return + + from hathor.verification.transaction_verifier import TransactionVerifier + shielded_fee = TransactionVerifier.calculate_shielded_fee(self._settings, vertex) + if shielded_fee == 0: + return + + # Look for an existing FeeHeader + existing_fee_header: FeeHeader | None = None + for header in vertex.headers: + if isinstance(header, FeeHeader): + existing_fee_header = header + break + + if existing_fee_header is not None: + # Augment the existing FeeHeader: find HTR entry and add shielded fee + new_fees: list[FeeHeaderEntry] = [] + found_htr = False + for entry in existing_fee_header.fees: + if entry.token_index == 0 and not found_htr: + # Augment the 
HTR fee entry + new_fees.append(FeeHeaderEntry(token_index=0, amount=entry.amount + shielded_fee)) + found_htr = True + else: + new_fees.append(entry) + if not found_htr: + new_fees.append(FeeHeaderEntry(token_index=0, amount=shielded_fee)) + existing_fee_header.fees = new_fees + else: + # Create a new FeeHeader with just the shielded fee + fee_header = FeeHeader( + settings=vertex._settings, + tx=vertex, + fees=[FeeHeaderEntry(token_index=0, amount=shielded_fee)], + ) + vertex.headers.append(fee_header) + + def add_shielded_outputs_header_if_needed(self, node: DAGNode, vertex: BaseTransaction) -> None: + """Collect outputs with [shielded] or [full-shielded] attrs into a ShieldedOutputsHeader.""" + # TODO: For each output with [shielded] or [full-shielded] attrs, generate an + # ephemeral keypair for ECDH recovery, derive Pedersen commitments and range proofs + # using hathor.crypto.shielded, create AmountShieldedOutput/FullShieldedOutput, and + # attach as ShieldedOutputsHeader. For full-shielded: also create asset commitments + # and surjection proofs. + return + + def _get_recipient_pubkey_from_script(self, script: bytes) -> bytes | None: + """Extract the recipient's compressed public key from a P2PKH script. + + Looks up the address in all wallets to find the corresponding public key. + Returns None if the public key cannot be determined. + """ + # TODO: Parse P2PKH script to get address, look up in wallets, extract + # compressed public key using extract_key_bytes from hathor.crypto.shielded.ecdh. 
+ return None + def create_vertex_on_chain_blueprint(self, node: DAGNode) -> OnChainBlueprint: """Create an OnChainBlueprint given a node.""" block_parents, txs_parents = self._create_vertex_parents(node) diff --git a/hathor/feature_activation/feature.py b/hathor/feature_activation/feature.py index 480ef5685..273af1abf 100644 --- a/hathor/feature_activation/feature.py +++ b/hathor/feature_activation/feature.py @@ -33,3 +33,4 @@ class Feature(StrEnum): NANO_CONTRACTS = 'NANO_CONTRACTS' FEE_TOKENS = 'FEE_TOKENS' OPCODES_V2 = 'OPCODES_V2' + SHIELDED_TRANSACTIONS = 'SHIELDED_TRANSACTIONS' diff --git a/hathor/feature_activation/utils.py b/hathor/feature_activation/utils.py index 47a01235b..d9817c9b1 100644 --- a/hathor/feature_activation/utils.py +++ b/hathor/feature_activation/utils.py @@ -36,6 +36,7 @@ class Features: nanocontracts: bool fee_tokens: bool opcodes_version: OpcodesVersion + shielded_transactions: bool @staticmethod def from_vertex(*, settings: HathorSettings, feature_service: FeatureService, vertex: Vertex) -> Features: @@ -47,6 +48,7 @@ def from_vertex(*, settings: HathorSettings, feature_service: FeatureService, ve Feature.NANO_CONTRACTS: settings.ENABLE_NANO_CONTRACTS, Feature.FEE_TOKENS: settings.ENABLE_FEE_BASED_TOKENS, Feature.OPCODES_V2: settings.ENABLE_OPCODES_V2, + Feature.SHIELDED_TRANSACTIONS: settings.ENABLE_SHIELDED_TRANSACTIONS, } feature_is_active: dict[Feature, bool] = { @@ -61,6 +63,7 @@ def from_vertex(*, settings: HathorSettings, feature_service: FeatureService, ve nanocontracts=feature_is_active[Feature.NANO_CONTRACTS], fee_tokens=feature_is_active[Feature.FEE_TOKENS], opcodes_version=opcodes_version, + shielded_transactions=feature_is_active[Feature.SHIELDED_TRANSACTIONS], ) diff --git a/hathor/indexes/utxo_index.py b/hathor/indexes/utxo_index.py index 8b5dcde93..96431d9d0 100644 --- a/hathor/indexes/utxo_index.py +++ b/hathor/indexes/utxo_index.py @@ -143,7 +143,12 @@ def _update_executed(self, tx: BaseTransaction) -> None: # remove 
all inputs for tx_input in tx.inputs: spent_tx = tx.get_spent_tx(tx_input) - spent_tx_output = spent_tx.outputs[tx_input.index] + # Use resolve_spent_output for shielded-aware lookup + resolved = spent_tx.resolve_spent_output(tx_input.index) + if not isinstance(resolved, TxOutput): + # Shielded outputs don't have public value/token for the UTXO index + continue + spent_tx_output = resolved log_it = log.new(tx_id=spent_tx.hash_hex, index=tx_input.index) if _should_skip_output(spent_tx_output): log_it.debug('ignore input') @@ -184,7 +189,12 @@ def _update_voided(self, tx: BaseTransaction) -> None: # re-add inputs that aren't voided for tx_input in tx.inputs: spent_tx = tx.get_spent_tx(tx_input) - spent_tx_output = spent_tx.outputs[tx_input.index] + # Use resolve_spent_output for shielded-aware lookup + resolved = spent_tx.resolve_spent_output(tx_input.index) + if not isinstance(resolved, TxOutput): + # Shielded outputs don't have public value/token for the UTXO index + continue + spent_tx_output = resolved log_it = log.new(tx_id=spent_tx.hash_hex, index=tx_input.index) if _should_skip_output(spent_tx_output): log_it.debug('ignore input') diff --git a/hathor/nanocontracts/vertex_data.py b/hathor/nanocontracts/vertex_data.py index 22305955b..59f404551 100644 --- a/hathor/nanocontracts/vertex_data.py +++ b/hathor/nanocontracts/vertex_data.py @@ -29,7 +29,11 @@ def _get_txin_output(vertex: BaseTransaction, txin: TxInput) -> TxOutput | None: - """Return the output that txin points to.""" + """Return the output that txin points to. + + Returns None for shielded outputs (they don't have TxOutput fields) + or when storage is unavailable. 
+ """ from hathor.transaction.storage.exceptions import TransactionDoesNotExist if vertex.storage is None: @@ -40,10 +44,18 @@ def _get_txin_output(vertex: BaseTransaction, txin: TxInput) -> TxOutput | None: except TransactionDoesNotExist: assert False, f'missing dependency: {txin.tx_id.hex()}' - assert len(vertex2.outputs) > txin.index, 'invalid output index' + # Use resolve_spent_output for shielded-aware lookup + try: + resolved = vertex2.resolve_spent_output(txin.index) + except IndexError: + return None + + # Only return TxOutput; shielded outputs lack value/token_data for TxOutputData + from hathor.transaction import TxOutput as _TxOutput + if not isinstance(resolved, _TxOutput): + return None - txin_output = vertex2.outputs[txin.index] - return txin_output + return resolved @dataclass(frozen=True, slots=True, kw_only=True) diff --git a/hathor/p2p/sync_v2/transaction_streaming_client.py b/hathor/p2p/sync_v2/transaction_streaming_client.py index 92402cd2d..3421f867b 100644 --- a/hathor/p2p/sync_v2/transaction_streaming_client.py +++ b/hathor/p2p/sync_v2/transaction_streaming_client.py @@ -54,6 +54,9 @@ def __init__(self, # it will be correctly enabled when doing a full validation anyway. # We can also set the `nc_block_root_id` to `None` because we only call `verify_basic`, # which doesn't need it. + # XXX: Default to shielded_transactions=False since shielded txs cannot exist + # before the feature is activated. The correct value will be computed when doing + # a full validation anyway. 
self.verification_params = VerificationParams( nc_block_root_id=None, features=Features( @@ -61,6 +64,7 @@ def __init__(self, nanocontracts=False, fee_tokens=False, opcodes_version=OpcodesVersion.V1, + shielded_transactions=False, ) ) diff --git a/hathor/transaction/base_transaction.py b/hathor/transaction/base_transaction.py index 37392b12b..57647a809 100644 --- a/hathor/transaction/base_transaction.py +++ b/hathor/transaction/base_transaction.py @@ -53,6 +53,7 @@ from hathor.conf.settings import HathorSettings from hathor.transaction import Transaction + from hathor.transaction.shielded_tx_output import OutputMode, ShieldedOutput from hathor.transaction.storage import TransactionStorage # noqa: F401 from hathor.transaction.vertex_children import VertexChildren @@ -269,6 +270,10 @@ def has_fees(self) -> bool: """Return whether this transaction has a fee header.""" return False + def has_shielded_outputs(self) -> bool: + """Return whether this vertex has shielded outputs.""" + return False + def get_fields_from_struct(self, struct_bytes: bytes, *, verbose: VerboseCallback = None) -> bytes: """ Gets all common fields for a Transaction and a Block from a buffer. @@ -298,7 +303,7 @@ def get_header_from_bytes(self, buf: bytes, *, verbose: VerboseCallback = None) def get_maximum_number_of_headers(self) -> int: """Return the maximum number of headers for this vertex.""" - return 2 + return 3 @classmethod @abstractmethod @@ -353,6 +358,29 @@ def hash_hex(self) -> str: else: return '' + @property + def shielded_outputs(self) -> list['ShieldedOutput']: + """Return the list of shielded outputs. Empty for non-Transaction vertices.""" + return [] + + def resolve_spent_output(self, index: int) -> 'TxOutput | ShieldedOutput': + """Resolve an output by index, checking both transparent and shielded outputs. + + CONS-017: 3-way lookup: transparent outputs first, then shielded, then raise. 
+ """ + if index < len(self.outputs): + return self.outputs[index] + shielded_idx = index - len(self.outputs) + shielded = self.shielded_outputs + if shielded_idx < len(shielded): + return shielded[shielded_idx] + raise IndexError(f'output index {index} out of range (transparent={len(self.outputs)}, ' + f'shielded={len(shielded)})') + + def is_shielded_output(self, index: int) -> bool: + """Return True if `index` refers to a shielded output (i.e. index >= len(self.outputs)).""" + return index >= len(self.outputs) and index < len(self.outputs) + len(self.shielded_outputs) + @property def sum_outputs(self) -> int: """Sum of the value of the outputs""" @@ -513,8 +541,12 @@ def add_address_from_output(output: 'TxOutput') -> None: for txin in self.inputs: tx2 = self.storage.get_transaction(txin.tx_id) - txout = tx2.outputs[txin.index] - add_address_from_output(txout) + # CONS-017: use resolve_spent_output for shielded-aware lookup + resolved = tx2.resolve_spent_output(txin.index) + from hathor.transaction.scripts import parse_address_script as _parse + script_type = _parse(resolved.script) + if script_type: + addresses.add(script_type.address) for txout in self.outputs: add_address_from_output(txout) @@ -862,11 +894,22 @@ def serialize_output(tx: BaseTransaction, tx_out: TxOutput) -> dict[str, Any]: for index, tx_in in enumerate(self.inputs): tx2 = self.storage.get_transaction(tx_in.tx_id) - tx2_out = tx2.outputs[tx_in.index] - output = serialize_output(tx2, tx2_out) - output['tx_id'] = tx2.hash_hex - output['index'] = tx_in.index - ret['inputs'].append(output) + # CONS-019: use resolve_spent_output for shielded-aware lookup + if tx2.is_shielded_output(tx_in.index): + shielded_out = tx2.resolve_spent_output(tx_in.index) + output_data: dict[str, Any] = { + 'type': 'shielded', + 'commitment': shielded_out.commitment.hex(), # type: ignore[union-attr] + 'script': shielded_out.script.hex(), + 'tx_id': tx2.hash_hex, + 'index': tx_in.index, + } + else: + tx2_out = 
tx2.outputs[tx_in.index] + output_data = serialize_output(tx2, tx2_out) + output_data['tx_id'] = tx2.hash_hex + output_data['index'] = tx_in.index + ret['inputs'].append(output_data) for index, tx_out in enumerate(self.outputs): spent_by = meta.get_output_spent_by(index) @@ -1100,6 +1143,12 @@ def get_token_index(self) -> int: """The token uid index in the list""" return self.token_data & self.TOKEN_INDEX_MASK + @staticmethod + def mode() -> OutputMode: + """Return the output mode (TRANSPARENT for standard TxOutput).""" + from hathor.transaction.shielded_tx_output import OutputMode as _OutputMode + return _OutputMode.TRANSPARENT + def is_token_authority(self) -> bool: """Whether this is a token authority output""" return (self.token_data & self.TOKEN_AUTHORITY_MASK) > 0 diff --git a/hathor/transaction/exceptions.py b/hathor/transaction/exceptions.py index 704dc7fe0..1231dbfa1 100644 --- a/hathor/transaction/exceptions.py +++ b/hathor/transaction/exceptions.py @@ -278,3 +278,31 @@ class InvalidFeeAmount(InvalidFeeHeader): class TokenNotFound(TxValidationError): """Token not found.""" + + +class InvalidRangeProofError(TxValidationError): + """Range proof is invalid.""" + + +class InvalidSurjectionProofError(TxValidationError): + """Surjection proof is invalid.""" + + +class ShieldedBalanceMismatchError(TxValidationError): + """Shielded balance equation does not hold.""" + + +class TrivialCommitmentError(TxValidationError): + """Rule 4: All transparent inputs require >= 2 shielded outputs.""" + + +class ShieldedAuthorityError(TxValidationError): + """Rule 7: Authority outputs cannot be shielded.""" + + +class ShieldedMintMeltForbiddenError(TxValidationError): + """Mint/melt operations are not allowed in transactions with shielded outputs.""" + + +class InvalidShieldedOutputError(TxValidationError): + """Generic invalid shielded output error.""" diff --git a/hathor/transaction/headers/__init__.py b/hathor/transaction/headers/__init__.py index 64efadf57..364889689 
100644 --- a/hathor/transaction/headers/__init__.py +++ b/hathor/transaction/headers/__init__.py @@ -15,6 +15,7 @@ from hathor.transaction.headers.base import VertexBaseHeader from hathor.transaction.headers.fee_header import FeeHeader from hathor.transaction.headers.nano_header import NanoHeader +from hathor.transaction.headers.shielded_outputs_header import ShieldedOutputsHeader from hathor.transaction.headers.types import VertexHeaderId __all__ = [ @@ -22,4 +23,5 @@ 'VertexHeaderId', 'NanoHeader', 'FeeHeader', + 'ShieldedOutputsHeader', ] diff --git a/hathor/transaction/headers/base.py b/hathor/transaction/headers/base.py index aba002ad9..3732bc367 100644 --- a/hathor/transaction/headers/base.py +++ b/hathor/transaction/headers/base.py @@ -24,6 +24,12 @@ class VertexBaseHeader(ABC): + @classmethod + @abstractmethod + def get_header_id(cls) -> bytes: + """Return the 1-byte header ID for this header type.""" + raise NotImplementedError + @classmethod @abstractmethod def deserialize( diff --git a/hathor/transaction/headers/fee_header.py b/hathor/transaction/headers/fee_header.py index d2bca54b5..5bca863a1 100644 --- a/hathor/transaction/headers/fee_header.py +++ b/hathor/transaction/headers/fee_header.py @@ -49,6 +49,10 @@ class FeeEntry: @dataclass(slots=True, kw_only=True) class FeeHeader(VertexBaseHeader): + @classmethod + def get_header_id(cls) -> bytes: + return VertexHeaderId.FEE_HEADER.value + # transaction that contains the fee header tx: 'Transaction' # list of tokens and amounts that will be used to pay fees in the transaction diff --git a/hathor/transaction/headers/nano_header.py b/hathor/transaction/headers/nano_header.py index cf49b94dc..ac7525d80 100644 --- a/hathor/transaction/headers/nano_header.py +++ b/hathor/transaction/headers/nano_header.py @@ -99,6 +99,10 @@ def _validate_authorities(self, token_uid: TokenUid) -> None: @dataclass(slots=True, kw_only=True) class NanoHeader(VertexBaseHeader): + @classmethod + def get_header_id(cls) -> bytes: + 
return VertexHeaderId.NANO_HEADER.value + tx: Transaction # Sequence number for the caller. diff --git a/hathor/transaction/headers/shielded_outputs_header.py b/hathor/transaction/headers/shielded_outputs_header.py new file mode 100644 index 000000000..3b290680d --- /dev/null +++ b/hathor/transaction/headers/shielded_outputs_header.py @@ -0,0 +1,119 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import struct +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from hathor.transaction.headers.base import VertexBaseHeader +from hathor.transaction.headers.types import VertexHeaderId +from hathor.transaction.shielded_tx_output import ( + MAX_SHIELDED_OUTPUTS, + ShieldedOutput, + deserialize_shielded_output, + get_sighash_bytes as output_sighash_bytes, + serialize_shielded_output, +) +from hathor.transaction.util import VerboseCallback, int_to_bytes + +if TYPE_CHECKING: + from hathor.transaction.base_transaction import BaseTransaction + from hathor.transaction.transaction import Transaction + + +@dataclass(slots=True, kw_only=True) +class ShieldedOutputsHeader(VertexBaseHeader): + @classmethod + def get_header_id(cls) -> bytes: + return VertexHeaderId.SHIELDED_OUTPUTS_HEADER.value + + tx: Transaction + shielded_outputs: list[ShieldedOutput] = field(default_factory=list) + + @classmethod + def deserialize( + cls, + tx: BaseTransaction, + buf: bytes, + *, + verbose: 
VerboseCallback = None, + ) -> tuple[ShieldedOutputsHeader, bytes]: + """Deserialize: header_id(1) | num_outputs(1) | outputs...""" + from hathor.transaction.exceptions import InvalidShieldedOutputError + from hathor.transaction.transaction import Transaction + + if not isinstance(tx, Transaction): + raise InvalidShieldedOutputError( + f'shielded outputs header requires a Transaction, got {type(tx).__name__}' + ) + + try: + offset = 0 + header_id = buf[offset:offset + 1] + offset += 1 + if verbose: + verbose('header_id', header_id) + assert header_id == VertexHeaderId.SHIELDED_OUTPUTS_HEADER.value + + num_outputs = buf[offset] + offset += 1 + if verbose: + verbose('num_shielded_outputs', num_outputs) + + if num_outputs < 1: + raise InvalidShieldedOutputError('shielded outputs header must contain at least 1 output') + if num_outputs > MAX_SHIELDED_OUTPUTS: + raise InvalidShieldedOutputError( + f'too many shielded outputs: {num_outputs} exceeds maximum {MAX_SHIELDED_OUTPUTS}' + ) + + shielded_outputs: list[ShieldedOutput] = [] + remaining = buf[offset:] + for _ in range(num_outputs): + output, remaining = deserialize_shielded_output(remaining) + shielded_outputs.append(output) + + except InvalidShieldedOutputError: + raise + except (IndexError, struct.error, ValueError) as e: + raise InvalidShieldedOutputError(f'malformed shielded outputs header: {e}') from e + + return cls( + tx=tx, + shielded_outputs=shielded_outputs, + ), remaining + + def serialize(self) -> bytes: + """Serialize: header_id(1) | num_outputs(1) | outputs...""" + parts: list[bytes] = [] + parts.append(VertexHeaderId.SHIELDED_OUTPUTS_HEADER.value) + parts.append(int_to_bytes(len(self.shielded_outputs), 1)) + + for output in self.shielded_outputs: + parts.append(serialize_shielded_output(output)) + + return b''.join(parts) + + def get_sighash_bytes(self) -> bytes: + """Include in sighash: header_id + count + per-output sighash bytes.""" + parts: list[bytes] = [] + 
parts.append(VertexHeaderId.SHIELDED_OUTPUTS_HEADER.value) + parts.append(int_to_bytes(len(self.shielded_outputs), 1)) + + for output in self.shielded_outputs: + parts.append(output_sighash_bytes(output)) + + return b''.join(parts) diff --git a/hathor/transaction/headers/types.py b/hathor/transaction/headers/types.py index 7b45b8a8e..386ea23ab 100644 --- a/hathor/transaction/headers/types.py +++ b/hathor/transaction/headers/types.py @@ -19,3 +19,4 @@ class VertexHeaderId(Enum): NANO_HEADER = b'\x10' FEE_HEADER = b'\x11' + SHIELDED_OUTPUTS_HEADER = b'\x12' diff --git a/hathor/transaction/resources/create_tx.py b/hathor/transaction/resources/create_tx.py index dbc58af52..99b035740 100644 --- a/hathor/transaction/resources/create_tx.py +++ b/hathor/transaction/resources/create_tx.py @@ -18,6 +18,7 @@ from hathor.api_util import Resource, set_cors from hathor.crypto.util import decode_address from hathor.exception import InvalidNewTransaction +from hathor.feature_activation.utils import Features from hathor.manager import HathorManager from hathor.transaction import Transaction, TxInput, TxOutput from hathor.transaction.scripts import create_output_script @@ -118,7 +119,12 @@ def _verify_unsigned_skip_pow(self, tx: Transaction) -> None: verifiers.vertex.verify_sigops_output(tx, enable_checkdatasig_count=True) verifiers.tx.verify_sigops_input(tx, enable_checkdatasig_count=True) best_block = self.manager.tx_storage.get_best_block() - params = VerificationParams.default_for_mempool(best_block=best_block) + features = Features.from_vertex( + settings=self.manager._settings, + feature_service=self.manager.feature_service, + vertex=best_block, + ) + params = VerificationParams.default_for_mempool(best_block=best_block, features=features) # need to run verify_inputs first to check if all inputs exist verifiers.tx.verify_inputs(tx, params, skip_script=True) verifiers.vertex.verify_parents(tx) diff --git a/hathor/transaction/scripts/execute.py 
b/hathor/transaction/scripts/execute.py index 1b393712d..5511ad03c 100644 --- a/hathor/transaction/scripts/execute.py +++ b/hathor/transaction/scripts/execute.py @@ -115,9 +115,15 @@ def script_eval(tx: Transaction, txin: TxInput, spent_tx: BaseTransaction, versi :raises ScriptError: if script verification fails """ + # VULN-002 / CONS-007: Use resolve_spent_output for shielded-aware lookup + try: + resolved = spent_tx.resolve_spent_output(txin.index) + except IndexError: + raise InvalidScriptError(f'input index {txin.index} out of range') + output_script = resolved.script raw_script_eval( input_data=txin.data, - output_script=spent_tx.outputs[txin.index].script, + output_script=output_script, extras=UtxoScriptExtras(tx=tx, txin=txin, spent_tx=spent_tx, version=version), ) diff --git a/hathor/transaction/scripts/opcode.py b/hathor/transaction/scripts/opcode.py index 0f39424fb..95f233263 100644 --- a/hathor/transaction/scripts/opcode.py +++ b/hathor/transaction/scripts/opcode.py @@ -515,7 +515,14 @@ def op_find_p2pkh(context: ScriptContext) -> None: spent_tx = context.extras.spent_tx txin = context.extras.txin tx = context.extras.tx - contract_value = spent_tx.outputs[txin.index].value + # CONS-020: use resolve_spent_output for shielded-aware lookup + from hathor.transaction.shielded_tx_output import OutputMode + resolved = spent_tx.resolve_spent_output(txin.index) + if resolved.mode() != OutputMode.TRANSPARENT: + raise VerifyFailed + from hathor.transaction import TxOutput + assert isinstance(resolved, TxOutput) + contract_value = resolved.value address = context.stack.pop() address_b58 = get_address_b58_from_bytes(address) diff --git a/hathor/transaction/shielded_output_secrets.py b/hathor/transaction/shielded_output_secrets.py new file mode 100644 index 000000000..b71423908 --- /dev/null +++ b/hathor/transaction/shielded_output_secrets.py @@ -0,0 +1,28 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataclass holding the secrets recovered from a shielded output via range proof rewind.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(slots=True, frozen=True) +class ShieldedOutputSecrets: + """Secrets recovered from a shielded output via ECDH + range proof rewind.""" + amount: int # Committed value + blinding_factor: bytes # 32B value blinding factor + token_uid: bytes # 32B token UID + asset_blinding_factor: bytes | None # 32B for FullShielded, None for AmountShielded diff --git a/hathor/transaction/shielded_tx_output.py b/hathor/transaction/shielded_tx_output.py new file mode 100644 index 000000000..5d9b6d799 --- /dev/null +++ b/hathor/transaction/shielded_tx_output.py @@ -0,0 +1,272 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import struct +from dataclasses import dataclass +from enum import IntEnum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + +COMMITMENT_SIZE = 33 +ASSET_COMMITMENT_SIZE = 33 +EPHEMERAL_PUBKEY_SIZE = 33 # Compressed secp256k1 public key +MAX_RANGE_PROOF_SIZE = 1024 # Valid Bulletproofs are ~675 bytes +MAX_SURJECTION_PROOF_SIZE = 4096 # Surjection proofs grow with input count +MAX_SHIELDED_OUTPUTS = 32 # Maximum number of shielded outputs per transaction +MAX_SHIELDED_OUTPUT_SCRIPT_SIZE = 1024 # Match settings.MAX_OUTPUT_SCRIPT_SIZE (VULN-001) + + +class OutputMode(IntEnum): + """Privacy level for an output.""" + TRANSPARENT = 0 # Standard TxOutput: amount, token ID, and script all visible + AMOUNT_ONLY = 1 # Amount hidden, token ID visible (no surjection proof) + FULLY_SHIELDED = 2 # Both amount and token ID hidden (surjection proof required) + + +@dataclass(slots=True, frozen=True) +class AmountShieldedOutput: + """Amount hidden, token ID visible. No surjection proof needed.""" + commitment: bytes # 33B Pedersen commitment (C = amount*H_token + r*G) + range_proof: bytes # ~675B Bulletproof + script: bytes # Locking script + token_data: int # Token index (like TxOutput.token_data) + ephemeral_pubkey: bytes = b'' # 33B compressed secp256k1 pubkey for ECDH recovery + + @staticmethod + def mode() -> OutputMode: + return OutputMode.AMOUNT_ONLY + + +@dataclass(slots=True, frozen=True) +class FullShieldedOutput: + """Both amount and token type hidden. 
Surjection proof required.""" + commitment: bytes # 33B Pedersen commitment + range_proof: bytes # ~675B Bulletproof + script: bytes # Locking script + asset_commitment: bytes # 33B blinded asset tag (A = H_token + r_asset*G) + surjection_proof: bytes # Variable, asset surjection proof + ephemeral_pubkey: bytes = b'' # 33B compressed secp256k1 pubkey for ECDH recovery + + @staticmethod + def mode() -> OutputMode: + return OutputMode.FULLY_SHIELDED + + +@dataclass(slots=True, frozen=True) +class ShieldedOutputSecrets: + """Recovered secrets from a shielded output via ECDH rewind.""" + value: int + blinding_factor: bytes + message: bytes + token_uid: bytes # Recovered or derived token UID + + +# Union type for headers and verifiers +ShieldedOutput = AmountShieldedOutput | FullShieldedOutput + + +def recover_shielded_secrets( + output: ShieldedOutput, + private_key_bytes: bytes, + get_token_uid: 'Callable[[int], bytes]', +) -> ShieldedOutputSecrets: + """Recover hidden values from a shielded output using ECDH + range proof rewind. + + Args: + output: The shielded output to recover secrets from. + private_key_bytes: The 32-byte secret key for ECDH. + get_token_uid: Callback to resolve token_data index to token UID (e.g., tx.get_token_uid). + + Returns: + ShieldedOutputSecrets with the recovered value, blinding factor, message, and token UID. + + Raises: + ValueError: If ECDH recovery fails or the output has no ephemeral pubkey. + """ + # TODO: Use ECDH shared secret derivation (derive_ecdh_shared_secret, derive_rewind_nonce + # from hathor.crypto.shielded.ecdh) with the output's ephemeral_pubkey and the recipient's + # private key. Then determine the generator (derive_asset_tag for AmountShielded, or + # asset_commitment for FullShielded). Finally call rewind_range_proof from + # hathor.crypto.shielded to extract (value, blinding_factor, message). For FullShieldedOutput, + # the token UID is embedded in the first 32 bytes of the recovered message. 
 + raise NotImplementedError('requires hathor-ct-crypto library') + + +def serialize_shielded_output(output: ShieldedOutput) -> bytes: + """Serialize a shielded output to bytes. + + Format: + mode(1) | commitment(33) | rp_len(2) | range_proof(var) | script_len(2) | script(var) | + [if AMOUNT_ONLY]: token_data(1) + [if FULLY_SHIELDED]: asset_commitment(33) | sp_len(2) | surjection_proof(var) + ephemeral_pubkey(33) (always present; all-zero bytes when absent) + """ + parts: list[bytes] = [] + parts.append(struct.pack('!B', output.mode())) + parts.append(output.commitment) + parts.append(struct.pack('!H', len(output.range_proof))) + parts.append(output.range_proof) + parts.append(struct.pack('!H', len(output.script))) + parts.append(output.script) + + if isinstance(output, AmountShieldedOutput): + parts.append(struct.pack('!B', output.token_data)) + elif isinstance(output, FullShieldedOutput): + parts.append(output.asset_commitment) + parts.append(struct.pack('!H', len(output.surjection_proof))) + parts.append(output.surjection_proof) + + # Ephemeral pubkey for ECDH-based recovery (always 33B; zeros = not present) + parts.append(output.ephemeral_pubkey if output.ephemeral_pubkey else b'\x00' * EPHEMERAL_PUBKEY_SIZE) + + return b''.join(parts) + + +def deserialize_shielded_output(buf: bytes | memoryview) -> tuple[ShieldedOutput, bytes]: + """Deserialize a shielded output from bytes. + + Returns (output, remaining_bytes). 
+ """ + view = memoryview(buf) if not isinstance(buf, memoryview) else buf + offset = 0 + + mode_byte = view[offset] + offset += 1 + mode = OutputMode(mode_byte) + + commitment = bytes(view[offset:offset + COMMITMENT_SIZE]) + offset += COMMITMENT_SIZE + if len(commitment) != COMMITMENT_SIZE: + raise ValueError( + f'truncated commitment: expected {COMMITMENT_SIZE} bytes, got {len(commitment)}' + ) + + (rp_len,) = struct.unpack_from('!H', view, offset) + offset += 2 + if rp_len > MAX_RANGE_PROOF_SIZE: + raise ValueError( + f'range proof size {rp_len} exceeds maximum {MAX_RANGE_PROOF_SIZE}' + ) + range_proof = bytes(view[offset:offset + rp_len]) + offset += rp_len + if len(range_proof) != rp_len: + raise ValueError( + f'truncated range proof: expected {rp_len} bytes, got {len(range_proof)}' + ) + + (script_len,) = struct.unpack_from('!H', view, offset) + offset += 2 + if script_len > MAX_SHIELDED_OUTPUT_SCRIPT_SIZE: + raise ValueError( + f'script size {script_len} exceeds maximum {MAX_SHIELDED_OUTPUT_SCRIPT_SIZE}' + ) + script = bytes(view[offset:offset + script_len]) + offset += script_len + if len(script) != script_len: + raise ValueError( + f'truncated script: expected {script_len} bytes, got {len(script)}' + ) + + if mode == OutputMode.AMOUNT_ONLY: + token_data = view[offset] + offset += 1 + + # Read ephemeral pubkey (always 33B; zeros = not present) + raw_ephemeral = bytes(view[offset:offset + EPHEMERAL_PUBKEY_SIZE]) + offset += EPHEMERAL_PUBKEY_SIZE + if len(raw_ephemeral) != EPHEMERAL_PUBKEY_SIZE: + raise ValueError( + f'truncated ephemeral_pubkey: expected {EPHEMERAL_PUBKEY_SIZE} bytes, ' + f'got {len(raw_ephemeral)}' + ) + ephemeral_pubkey = b'' if raw_ephemeral == b'\x00' * EPHEMERAL_PUBKEY_SIZE else raw_ephemeral + + output: ShieldedOutput = AmountShieldedOutput( + commitment=commitment, + range_proof=range_proof, + script=script, + token_data=token_data, + ephemeral_pubkey=ephemeral_pubkey, + ) + elif mode == OutputMode.FULLY_SHIELDED: + asset_commitment = 
bytes(view[offset:offset + ASSET_COMMITMENT_SIZE]) + offset += ASSET_COMMITMENT_SIZE + if len(asset_commitment) != ASSET_COMMITMENT_SIZE: + raise ValueError( + f'truncated asset_commitment: expected {ASSET_COMMITMENT_SIZE} bytes, ' + f'got {len(asset_commitment)}' + ) + + (sp_len,) = struct.unpack_from('!H', view, offset) + offset += 2 + if sp_len > MAX_SURJECTION_PROOF_SIZE: + raise ValueError( + f'surjection proof size {sp_len} exceeds maximum {MAX_SURJECTION_PROOF_SIZE}' + ) + surjection_proof = bytes(view[offset:offset + sp_len]) + offset += sp_len + if len(surjection_proof) != sp_len: + raise ValueError( + f'truncated surjection proof: expected {sp_len} bytes, got {len(surjection_proof)}' + ) + + # Read ephemeral pubkey (always 33B; zeros = not present) + raw_ephemeral = bytes(view[offset:offset + EPHEMERAL_PUBKEY_SIZE]) + offset += EPHEMERAL_PUBKEY_SIZE + if len(raw_ephemeral) != EPHEMERAL_PUBKEY_SIZE: + raise ValueError( + f'truncated ephemeral_pubkey: expected {EPHEMERAL_PUBKEY_SIZE} bytes, ' + f'got {len(raw_ephemeral)}' + ) + ephemeral_pubkey = b'' if raw_ephemeral == b'\x00' * EPHEMERAL_PUBKEY_SIZE else raw_ephemeral + + output = FullShieldedOutput( + commitment=commitment, + range_proof=range_proof, + script=script, + asset_commitment=asset_commitment, + surjection_proof=surjection_proof, + ephemeral_pubkey=ephemeral_pubkey, + ) + else: + raise ValueError(f'Unknown shielded output mode: {mode_byte}') + + return output, bytes(view[offset:]) + + +def get_sighash_bytes(output: ShieldedOutput) -> bytes: + """Return sighash bytes for a shielded output. + + Includes commitment + mode + token_data/asset_commitment + script. + Does NOT include proofs (range_proof, surjection_proof). 
+ """ + parts: list[bytes] = [] + parts.append(struct.pack('!B', output.mode())) + parts.append(output.commitment) + + if isinstance(output, AmountShieldedOutput): + parts.append(struct.pack('!B', output.token_data)) + elif isinstance(output, FullShieldedOutput): + parts.append(output.asset_commitment) + + parts.append(output.script) + + # Always include ephemeral pubkey in sighash to prevent malleability + # where someone strips the ephemeral pubkey. Use zero bytes if not present. + parts.append(output.ephemeral_pubkey if output.ephemeral_pubkey else b'\x00' * EPHEMERAL_PUBKEY_SIZE) + + return b''.join(parts) diff --git a/hathor/transaction/token_info.py b/hathor/transaction/token_info.py index c713133d1..df082f10d 100644 --- a/hathor/transaction/token_info.py +++ b/hathor/transaction/token_info.py @@ -75,7 +75,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.fees_from_fee_header: int = 0 - def calculate_fee(self, settings: 'HathorSettings') -> int: + def calculate_fee(self, settings: 'HathorSettings', *, shielded_fee: int = 0) -> int: """ Calculate the total fee based on the number of chargeable outputs and inputs for each token in the transaction. @@ -87,15 +87,17 @@ def calculate_fee(self, settings: 'HathorSettings') -> int: as `chargeable_outputs * settings.FEE_PER_OUTPUT`. - If a token has zero chargeable outputs but one or more chargeable inputs, a flat fee of `settings.FEE_PER_OUTPUT` is applied. + - An additional shielded_fee is added for shielded outputs. Args: settings (HathorSettings): The configuration object containing fee-related parameters, such as `FEE_PER_OUTPUT`. + shielded_fee: Additional fee for shielded outputs (default 0). 
Returns: int: The total transaction fee """ - fee = 0 + fee = shielded_fee for token_uid, token_info in self.items(): if token_info.chargeable_outputs > 0: diff --git a/hathor/transaction/transaction.py b/hathor/transaction/transaction.py index f2336bf35..e9dc3a79e 100644 --- a/hathor/transaction/transaction.py +++ b/hathor/transaction/transaction.py @@ -26,8 +26,9 @@ from hathor.transaction import TxInput, TxOutput, TxVersion from hathor.transaction.base_transaction import TX_HASH_SIZE, GenericVertex from hathor.transaction.exceptions import InvalidToken -from hathor.transaction.headers import NanoHeader, VertexBaseHeader +from hathor.transaction.headers import NanoHeader, ShieldedOutputsHeader, VertexBaseHeader from hathor.transaction.headers.fee_header import FeeHeader +from hathor.transaction.shielded_tx_output import ShieldedOutput from hathor.transaction.static_metadata import TransactionStaticMetadata from hathor.transaction.token_info import TokenInfo, TokenInfoDict, TokenVersion, get_token_version from hathor.transaction.util import VerboseCallback, unpack, unpack_len @@ -133,6 +134,26 @@ def get_fee_header(self) -> FeeHeader: """Return the FeeHeader or raise ValueError.""" return self._get_header(FeeHeader) + def has_shielded_outputs(self) -> bool: + """Returns true if this transaction has a shielded outputs header.""" + try: + self.get_shielded_outputs_header() + except ValueError: + return False + else: + return True + + def get_shielded_outputs_header(self) -> ShieldedOutputsHeader: + """Return the ShieldedOutputsHeader or raise ValueError.""" + return self._get_header(ShieldedOutputsHeader) + + @property + def shielded_outputs(self) -> list[ShieldedOutput]: + """Return the list of shielded outputs, or empty list if no header.""" + if self.has_shielded_outputs(): + return self.get_shielded_outputs_header().shielded_outputs + return [] + def _get_header(self, header_type: type[T]) -> T: """Return the header of the given type or raise ValueError.""" for 
header in self.headers: @@ -410,8 +431,23 @@ def _get_token_info_from_inputs(self, nc_block_storage: NCBlockStorage) -> Token for tx_input in self.inputs: spent_tx = self.get_spent_tx(tx_input) - spent_output = spent_tx.outputs[tx_input.index] + # CONS-002: Use resolve_spent_output for shielded-aware lookup. + # Shielded inputs are skipped for token accounting — their amounts + # are verified by the homomorphic balance equation instead. + try: + resolved = spent_tx.resolve_spent_output(tx_input.index) + except IndexError: + # Out of bounds — will be caught by _verify_inputs + continue + + from hathor.transaction.shielded_tx_output import OutputMode + if resolved.mode() != OutputMode.TRANSPARENT: + # Shielded input: skip for token info (amount is hidden) + continue + + assert isinstance(resolved, TxOutput) + spent_output = resolved token_uid = spent_tx.get_token_uid(spent_output.get_token_index()) token_version = get_token_version(self.storage, nc_block_storage, token_uid) diff --git a/hathor/transaction/vertex_parser.py b/hathor/transaction/vertex_parser.py index 85850a18a..8016a3fde 100644 --- a/hathor/transaction/vertex_parser.py +++ b/hathor/transaction/vertex_parser.py @@ -17,7 +17,7 @@ from struct import error as StructError from typing import TYPE_CHECKING, Type -from hathor.transaction.headers import FeeHeader, NanoHeader, VertexBaseHeader, VertexHeaderId +from hathor.transaction.headers import FeeHeader, NanoHeader, ShieldedOutputsHeader, VertexBaseHeader, VertexHeaderId if TYPE_CHECKING: from hathor.conf.settings import HathorSettings @@ -39,6 +39,8 @@ def get_supported_headers(settings: HathorSettings) -> dict[VertexHeaderId, Type supported_headers[VertexHeaderId.NANO_HEADER] = NanoHeader if settings.ENABLE_FEE_BASED_TOKENS: supported_headers[VertexHeaderId.FEE_HEADER] = FeeHeader + if settings.ENABLE_SHIELDED_TRANSACTIONS: + supported_headers[VertexHeaderId.SHIELDED_OUTPUTS_HEADER] = ShieldedOutputsHeader return supported_headers @staticmethod diff --git 
a/hathor/verification/shielded_transaction_verifier.py b/hathor/verification/shielded_transaction_verifier.py new file mode 100644 index 000000000..92180c2b1 --- /dev/null +++ b/hathor/verification/shielded_transaction_verifier.py @@ -0,0 +1,211 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from structlog import get_logger + +from hathor.transaction.exceptions import ( + InvalidShieldedOutputError, + ShieldedAuthorityError, + ShieldedMintMeltForbiddenError, + TrivialCommitmentError, +) +from hathor.transaction.shielded_tx_output import AmountShieldedOutput, FullShieldedOutput +from hathor.transaction.token_info import TokenInfoDict, TokenVersion + +if TYPE_CHECKING: + from hathor.conf.settings import HathorSettings + from hathor.transaction.transaction import Transaction + + +_CRYPTO_TOKEN_UID_SIZE = 32 + + +def _normalize_token_uid(token_uid: bytes) -> bytes: + """Normalize a token UID to 32 bytes for the crypto library. + + Hathor uses b'\\x00' (1 byte) for HTR and 32-byte hashes for custom tokens. + The crypto library always expects 32-byte token UIDs. 
+ """ + if len(token_uid) == _CRYPTO_TOKEN_UID_SIZE: + return token_uid + if len(token_uid) == 1: + return token_uid.ljust(_CRYPTO_TOKEN_UID_SIZE, b'\x00') + raise InvalidShieldedOutputError( + f'invalid token UID length: expected 1 or {_CRYPTO_TOKEN_UID_SIZE} bytes, got {len(token_uid)}' + ) + + +logger = get_logger() + + +class ShieldedTransactionVerifier: + __slots__ = ('_settings', 'log') + + def __init__(self, *, settings: HathorSettings) -> None: + self._settings = settings + self.log = logger.new() + + @staticmethod + def calculate_shielded_fee(settings: HathorSettings, tx: Transaction) -> int: + """Calculate the total fee required for shielded outputs.""" + fee = 0 + for output in tx.shielded_outputs: + if isinstance(output, AmountShieldedOutput): + fee += settings.FEE_PER_AMOUNT_SHIELDED_OUTPUT + elif isinstance(output, FullShieldedOutput): + fee += settings.FEE_PER_FULL_SHIELDED_OUTPUT + return fee + + def verify_shielded_fee(self, tx: Transaction) -> None: + """Verify the transaction declares sufficient fees for its shielded outputs.""" + if not tx.has_fees(): + raise InvalidShieldedOutputError('shielded transactions require a fee header') + fee_header = tx.get_fee_header() + expected_shielded_fee = self.calculate_shielded_fee(self._settings, tx) + total_declared_fee = fee_header.total_fee_amount() + if total_declared_fee < expected_shielded_fee: + raise InvalidShieldedOutputError( + f'insufficient fee for shielded outputs: declared {total_declared_fee}, ' + f'minimum shielded fee is {expected_shielded_fee}' + ) + + def verify_no_mint_melt(self, token_dict: TokenInfoDict) -> None: + """Reject mint/melt operations in transactions with shielded outputs. + + The homomorphic balance equation enforces conservation (inputs = outputs). + Minting/melting breaks this equation. Additionally, the deposit calculation + in verify_token_rules only tracks transparent flows, so it would be incorrect + for shielded minting. 
+ + Authority pass-through (amount=0 with can_mint/can_melt) is still allowed. + """ + for token_uid, token_info in token_dict.items(): + if token_info.version == TokenVersion.NATIVE: + continue + if token_info.can_mint and token_info.has_been_minted(): + raise ShieldedMintMeltForbiddenError( + f'token {token_uid.hex()}: minting is not allowed in transactions ' + f'with shielded outputs (transparent surplus: {token_info.amount})' + ) + if token_info.can_melt and token_info.has_been_melted(): + raise ShieldedMintMeltForbiddenError( + f'token {token_uid.hex()}: melting is not allowed in transactions ' + f'with shielded outputs (transparent deficit: {token_info.amount})' + ) + + def verify_shielded_outputs(self, tx: Transaction) -> None: + """Top-level: calls all checks.""" + self.verify_commitments_valid(tx) + self.verify_authority_restriction(tx) # VULN-004: must run before range_proofs + self.verify_range_proofs(tx) + self.verify_trivial_commitment_protection(tx) + self.verify_shielded_fee(tx) + + def verify_shielded_outputs_with_storage(self, tx: Transaction) -> None: + """Shielded verifications that need storage (balance, surjection, trivial commitment).""" + self.verify_surjection_proofs(tx) + self.verify_shielded_balance(tx) + self._verify_trivial_commitment_with_storage(tx) + + def _verify_trivial_commitment_with_storage(self, tx: Transaction) -> None: + """VULN-008: Storage-aware trivial commitment protection. + + If all inputs are transparent, require >= 2 shielded outputs. + If any input is shielded, allow 1 shielded output. 
+ """ + if not tx.shielded_outputs: + return + if self._has_shielded_input(tx): + return # Relaxed: shielded inputs already provide mixing + if len(tx.shielded_outputs) < 2: + raise TrivialCommitmentError( + 'when all inputs are transparent, at least 2 shielded outputs are required ' + f'to prevent trivial commitment matching (got {len(tx.shielded_outputs)})' + ) + + def verify_commitments_valid(self, tx: Transaction) -> None: + """Validate all commitments are exactly 33 bytes, valid curve points, and count is within limits.""" + # TODO: Verify output count <= MAX_SHIELDED_OUTPUTS. For each shielded output, check + # commitment size == COMMITMENT_SIZE (33B) and call validate_commitment() from + # hathor.crypto.shielded to ensure it's a valid secp256k1 curve point (VULN-007). + # For FullShieldedOutput, also check asset_commitment size == ASSET_COMMITMENT_SIZE + # and call validate_generator(). Validate ephemeral_pubkey size and curve point validity. + pass + + def verify_range_proofs(self, tx: Transaction) -> None: + """Rule 5: Every shielded output must have valid Bulletproof range proof.""" + # TODO: For each shielded output, derive the generator: for AmountShieldedOutput use + # derive_asset_tag(token_uid) from hathor.crypto.shielded; for FullShieldedOutput use + # output.asset_commitment. Then call verify_range_proof(proof, commitment, generator) + # to validate the Bulletproof range proof (proves amount in [0, 2^64)). + pass + + def verify_surjection_proofs(self, tx: Transaction) -> None: + """Rule 6: Only FullShieldedOutput instances require surjection proofs.""" + # TODO: Build domain of input asset generators: for transparent inputs use + # derive_asset_tag(token_uid), for shielded inputs use asset_commitment (FullShielded) + # or derive_asset_tag (AmountShielded). 
Then for each FullShieldedOutput, call + # verify_surjection_proof(proof, asset_commitment, domain_generators) from + # hathor.crypto.shielded to prove the output's token type is one of the inputs. + pass + + def verify_shielded_balance(self, tx: Transaction) -> None: + """Homomorphic balance verification. sum(C_in) == sum(C_out) + fee*H_HTR""" + # TODO: Collect transparent inputs/outputs as (value, token_uid) pairs and shielded + # inputs/outputs as commitment bytes. Append fee entries as transparent outputs. + # Call verify_balance(transparent_inputs, shielded_inputs, transparent_outputs, + # shielded_outputs) from hathor.crypto.shielded to check the homomorphic balance + # equation: sum(C_in) == sum(C_out) + fee*H_HTR. + pass + + def verify_authority_restriction(self, tx: Transaction) -> None: + """Rule 7: Shielded outputs cannot be authority (mint/melt) outputs.""" + for i, output in enumerate(tx.shielded_outputs): + if isinstance(output, AmountShieldedOutput): + # Check if token_data has authority bits set + from hathor.transaction import TxOutput + if output.token_data & TxOutput.TOKEN_AUTHORITY_MASK: + raise ShieldedAuthorityError( + f'shielded output {i}: authority outputs cannot be shielded' + ) + + def verify_trivial_commitment_protection(self, tx: Transaction) -> None: + """Rule 4: Without storage, conservatively require >= 2 shielded outputs always. + + VULN-008: The storage-less version cannot determine if inputs are shielded, + so it conservatively requires >= 2 shielded outputs in all cases. + The storage-aware version (_has_shielded_input) can relax this. 
+ """ + if not tx.shielded_outputs: + return + + if len(tx.shielded_outputs) < 2: + raise TrivialCommitmentError( + 'at least 2 shielded outputs are required ' + f'to prevent trivial commitment matching (got {len(tx.shielded_outputs)})' + ) + + def _has_shielded_input(self, tx: Transaction) -> bool: + """Check if any input references a shielded output (requires storage).""" + assert tx.storage is not None + for tx_input in tx.inputs: + spent_tx = tx.storage.get_transaction(tx_input.tx_id) + if tx_input.index >= len(spent_tx.outputs): + # Index beyond transparent outputs → references shielded output + return True + return False diff --git a/hathor/verification/transaction_verifier.py b/hathor/verification/transaction_verifier.py index 999b319cf..1c34c86bb 100644 --- a/hathor/verification/transaction_verifier.py +++ b/hathor/verification/transaction_verifier.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, assert_never +from structlog import get_logger + from hathor.daa import DifficultyAdjustmentAlgorithm from hathor.feature_activation.feature_service import FeatureService from hathor.profiler import get_cpu_profiler @@ -34,10 +36,13 @@ InputVoidedAndConfirmed, InvalidInputData, InvalidInputDataSize, + InvalidShieldedOutputError, InvalidToken, InvalidVersionError, RewardLocked, ScriptError, + ShieldedAuthorityError, + ShieldedMintMeltForbiddenError, TimestampError, TokenNotFound, TooFewInputs, @@ -46,6 +51,7 @@ TooManySigOps, TooManyTokens, TooManyWithinConflicts, + TrivialCommitmentError, UnusedTokensError, WeightError, ) @@ -59,13 +65,15 @@ cpu = get_cpu_profiler() +logger = get_logger() + MAX_TOKENS_LENGTH: int = 16 MAX_WITHIN_CONFLICTS: int = 8 MAX_BETWEEN_CONFLICTS: int = 8 class TransactionVerifier: - __slots__ = ('_settings', '_daa', '_feature_service') + __slots__ = ('_settings', '_daa', '_feature_service', 'log') def __init__( self, @@ -77,6 +85,7 @@ def __init__( self._settings = settings self._daa = daa self._feature_service = feature_service + 
self.log = logger.new() def verify_parents_basic(self, tx: Transaction) -> None: """Verify number and non-duplicity of parents.""" @@ -119,10 +128,20 @@ def verify_sigops_input(self, tx: Transaction, enable_checkdatasig_count: bool = spent_tx = tx.get_spent_tx(tx_input) except TransactionDoesNotExist: raise InexistentInput('Input tx does not exist: {}'.format(tx_input.tx_id.hex())) - if tx_input.index >= len(spent_tx.outputs): + # VULN-002: Handle shielded output references + if tx_input.index < len(spent_tx.outputs): + script = spent_tx.outputs[tx_input.index].script + elif spent_tx.shielded_outputs: + shielded_idx = tx_input.index - len(spent_tx.outputs) + if shielded_idx < len(spent_tx.shielded_outputs): + script = spent_tx.shielded_outputs[shielded_idx].script + else: + raise InexistentInput('Output spent by this input does not exist: {} index {}'.format( + tx_input.tx_id.hex(), tx_input.index)) + else: raise InexistentInput('Output spent by this input does not exist: {} index {}'.format( tx_input.tx_id.hex(), tx_input.index)) - n_txops += counter.get_sigops_count(tx_input.data, spent_tx.outputs[tx_input.index].script) + n_txops += counter.get_sigops_count(tx_input.data, script) if n_txops > self._settings.MAX_TX_SIGOPS_INPUT: raise TooManySigOps( @@ -149,7 +168,19 @@ def _verify_inputs( )) spent_tx = tx.get_spent_tx(input_tx) - assert input_tx.index < len(spent_tx.outputs) + + # VULN-002: Handle shielded output references instead of asserting + if input_tx.index < len(spent_tx.outputs): + # Standard transparent output + pass + elif spent_tx.shielded_outputs: + shielded_idx = input_tx.index - len(spent_tx.outputs) + if shielded_idx >= len(spent_tx.shielded_outputs): + raise InexistentInput('Output spent by this input does not exist: {} index {}'.format( + input_tx.tx_id.hex(), input_tx.index)) + else: + raise InexistentInput('Output spent by this input does not exist: {} index {}'.format( + input_tx.tx_id.hex(), input_tx.index)) if tx.timestamp <= 
spent_tx.timestamp: raise TimestampError('tx={} timestamp={}, spent_tx={} timestamp={}'.format( @@ -261,6 +292,8 @@ def verify_sum( tx: Transaction, token_dict: TokenInfoDict, allow_nonexistent_tokens: bool = False, + *, + shielded_fee: int = 0, ) -> None: """Verify that the sum of outputs is equal of the sum of inputs, for each token. If sum of inputs and outputs is not 0, make sure inputs have mint/melt authority. @@ -321,7 +354,7 @@ def verify_sum( assert tx.is_nano_contract() return - expected_fee = token_dict.calculate_fee(settings) + expected_fee = token_dict.calculate_fee(settings, shielded_fee=shielded_fee) if expected_fee != token_dict.fees_from_fee_header: raise InputOutputMismatch(f"Fee amount is different than expected. " f"(amount={token_dict.fees_from_fee_header}, expected={expected_fee})") @@ -334,6 +367,69 @@ def verify_sum( assert htr_info.amount == htr_expected_amount + @classmethod + def verify_token_rules( + cls, + settings: HathorSettings, + token_dict: TokenInfoDict, + *, + shielded_fee: int = 0, + ) -> None: + """Verify token authority permissions, deposit requirements, and fee correctness. + + This method extracts the non-balance checks from verify_sum so they can be enforced + for shielded transactions too (where verify_sum's balance equation is replaced by + verify_shielded_balance, but these rules must still apply). 
+ + :raises ForbiddenMint: if tokens were minted without authority + :raises ForbiddenMelt: if tokens were melted without authority + :raises InputOutputMismatch: if HTR deposit or fee amounts are incorrect + """ + deposit = 0 + withdraw = 0 + + for token_uid, token_info in token_dict.items(): + cls._check_token_permissions(token_uid, token_info) + match token_info.version: + case None: + # Nonexistent tokens are not expected here (shielded txs are not nanos) + pass + + case TokenVersion.NATIVE: + continue + + case TokenVersion.DEPOSIT: + if token_info.has_been_melted(): + withdraw += get_deposit_token_withdraw_amount(settings, token_info.amount) + if token_info.has_been_minted(): + deposit += get_deposit_token_deposit_amount(settings, token_info.amount) + + case TokenVersion.FEE: + continue + + case _: + assert_never(token_info.version) + + # check whether the deposit/withdraw amount is correct + htr_expected_amount = withdraw - deposit + htr_info = token_dict[settings.HATHOR_TOKEN_UID] + if htr_info.amount > htr_expected_amount: + raise InputOutputMismatch('There\'s an invalid surplus of HTR. (amount={}, expected={})'.format( + htr_info.amount, + htr_expected_amount, + )) + + expected_fee = token_dict.calculate_fee(settings, shielded_fee=shielded_fee) + if expected_fee != token_dict.fees_from_fee_header: + raise InputOutputMismatch(f"Fee amount is different than expected. " + f"(amount={token_dict.fees_from_fee_header}, expected={expected_fee})") + + if htr_info.amount < htr_expected_amount: + raise InputOutputMismatch('There\'s an invalid deficit of HTR. 
(amount={}, expected={})'.format( + htr_info.amount, + htr_expected_amount, + )) + @staticmethod def _check_token_permissions(token_uid: TokenUid, token_info: TokenInfo) -> None: """Verify whether token can be minted/melted based on its authority.""" @@ -377,6 +473,12 @@ def verify_tokens(self, tx: Transaction, params: VerificationParams) -> None: for txout in tx.outputs: seen_token_indexes.add(txout.get_token_index()) + # VULN-013: Consider shielded output token indexes + from hathor.transaction.shielded_tx_output import AmountShieldedOutput + for shielded_out in tx.shielded_outputs: + if isinstance(shielded_out, AmountShieldedOutput): + seen_token_indexes.add(shielded_out.token_data & 0x7F) + if tx.is_nano_contract(): nano_header = tx.get_nano_header() for action in nano_header.nc_actions: @@ -424,3 +526,151 @@ def verify_conflict(self, tx: Transaction, params: VerificationParams) -> None: if between_counter > MAX_BETWEEN_CONFLICTS: raise TooManyBetweenConflicts + + # --- Shielded transaction verification methods --- + + _CRYPTO_TOKEN_UID_SIZE = 32 + + @staticmethod + def _normalize_token_uid(token_uid: bytes) -> bytes: + """Normalize a token UID to 32 bytes for the crypto library.""" + if len(token_uid) == TransactionVerifier._CRYPTO_TOKEN_UID_SIZE: + return token_uid + if len(token_uid) == 1: + return token_uid.ljust(TransactionVerifier._CRYPTO_TOKEN_UID_SIZE, b'\x00') + raise InvalidShieldedOutputError( + f'invalid token UID length: expected 1 or {TransactionVerifier._CRYPTO_TOKEN_UID_SIZE} bytes, ' + f'got {len(token_uid)}' + ) + + @staticmethod + def calculate_shielded_fee(settings: HathorSettings, tx: Transaction) -> int: + """Calculate the total fee required for shielded outputs.""" + from hathor.transaction.shielded_tx_output import AmountShieldedOutput, FullShieldedOutput + fee = 0 + for output in tx.shielded_outputs: + if isinstance(output, AmountShieldedOutput): + fee += settings.FEE_PER_AMOUNT_SHIELDED_OUTPUT + elif isinstance(output, 
FullShieldedOutput): + fee += settings.FEE_PER_FULL_SHIELDED_OUTPUT + return fee + + def verify_shielded_fee(self, tx: Transaction) -> None: + """Verify the transaction declares sufficient fees for its shielded outputs.""" + if not tx.has_fees(): + raise InvalidShieldedOutputError('shielded transactions require a fee header') + fee_header = tx.get_fee_header() + expected_shielded_fee = self.calculate_shielded_fee(self._settings, tx) + total_declared_fee = fee_header.total_fee_amount() + if total_declared_fee < expected_shielded_fee: + raise InvalidShieldedOutputError( + f'insufficient fee for shielded outputs: declared {total_declared_fee}, ' + f'minimum shielded fee is {expected_shielded_fee}' + ) + + def verify_no_mint_melt(self, token_dict: TokenInfoDict) -> None: + """Reject mint/melt operations in transactions with shielded outputs.""" + for token_uid, token_info in token_dict.items(): + if token_info.version == TokenVersion.NATIVE: + continue + if token_info.can_mint and token_info.has_been_minted(): + raise ShieldedMintMeltForbiddenError( + f'token {token_uid.hex()}: minting is not allowed in transactions ' + f'with shielded outputs (transparent surplus: {token_info.amount})' + ) + if token_info.can_melt and token_info.has_been_melted(): + raise ShieldedMintMeltForbiddenError( + f'token {token_uid.hex()}: melting is not allowed in transactions ' + f'with shielded outputs (transparent deficit: {token_info.amount})' + ) + + def verify_shielded_outputs(self, tx: Transaction) -> None: + """Top-level: calls all basic shielded checks.""" + self.verify_commitments_valid(tx) + self.verify_authority_restriction(tx) + self.verify_range_proofs(tx) + self.verify_trivial_commitment_protection(tx) + self.verify_shielded_fee(tx) + + def verify_shielded_outputs_with_storage(self, tx: Transaction) -> None: + """Shielded verifications that need storage (balance, surjection, trivial commitment).""" + self.verify_surjection_proofs(tx) + self.verify_shielded_balance(tx) + 
self._verify_trivial_commitment_with_storage(tx) + + def _verify_trivial_commitment_with_storage(self, tx: Transaction) -> None: + """VULN-008: Storage-aware trivial commitment protection.""" + if not tx.shielded_outputs: + return + if self._has_shielded_input(tx): + return + if len(tx.shielded_outputs) < 2: + raise TrivialCommitmentError( + 'when all inputs are transparent, at least 2 shielded outputs are required ' + f'to prevent trivial commitment matching (got {len(tx.shielded_outputs)})' + ) + + def verify_commitments_valid(self, tx: Transaction) -> None: + """Validate all commitments are exactly 33 bytes, valid curve points, and count is within limits.""" + # TODO: Verify output count <= MAX_SHIELDED_OUTPUTS. For each shielded output, check + # commitment size == COMMITMENT_SIZE (33B) and call validate_commitment() from + # hathor.crypto.shielded to ensure it's a valid secp256k1 curve point (VULN-007). + # For FullShieldedOutput, also check asset_commitment size == ASSET_COMMITMENT_SIZE + # and call validate_generator(). Validate ephemeral_pubkey size and curve point validity. + pass + + def verify_range_proofs(self, tx: Transaction) -> None: + """Every shielded output must have valid Bulletproof range proof.""" + # TODO: For each shielded output, derive the generator: for AmountShieldedOutput use + # derive_asset_tag(token_uid) from hathor.crypto.shielded; for FullShieldedOutput use + # output.asset_commitment. Then call verify_range_proof(proof, commitment, generator) + # to validate the Bulletproof range proof (proves amount in [0, 2^64)). + pass + + def verify_surjection_proofs(self, tx: Transaction) -> None: + """Only FullShieldedOutput instances require surjection proofs.""" + # TODO: Build domain of input asset generators: for transparent inputs use + # derive_asset_tag(token_uid), for shielded inputs use asset_commitment (FullShielded) + # or derive_asset_tag (AmountShielded). 
Then for each FullShieldedOutput, call + # verify_surjection_proof(proof, asset_commitment, domain_generators) from + # hathor.crypto.shielded to prove the output's token type is one of the inputs. + pass + + def verify_shielded_balance(self, tx: Transaction) -> None: + """Homomorphic balance verification: sum(C_in) == sum(C_out) + fee*H_HTR.""" + # TODO: Collect transparent inputs/outputs as (value, token_uid) pairs and shielded + # inputs/outputs as commitment bytes. Append fee entries as transparent outputs. + # Call verify_balance(transparent_inputs, shielded_inputs, transparent_outputs, + # shielded_outputs) from hathor.crypto.shielded to check the homomorphic balance + # equation: sum(C_in) == sum(C_out) + fee*H_HTR. + pass + + def verify_authority_restriction(self, tx: Transaction) -> None: + """Shielded outputs cannot be authority (mint/melt) outputs.""" + from hathor.transaction.shielded_tx_output import AmountShieldedOutput + for i, output in enumerate(tx.shielded_outputs): + if isinstance(output, AmountShieldedOutput): + from hathor.transaction import TxOutput + if output.token_data & TxOutput.TOKEN_AUTHORITY_MASK: + raise ShieldedAuthorityError( + f'shielded output {i}: authority outputs cannot be shielded' + ) + + def verify_trivial_commitment_protection(self, tx: Transaction) -> None: + """Without storage, conservatively require >= 2 shielded outputs always.""" + if not tx.shielded_outputs: + return + if len(tx.shielded_outputs) < 2: + raise TrivialCommitmentError( + 'at least 2 shielded outputs are required ' + f'to prevent trivial commitment matching (got {len(tx.shielded_outputs)})' + ) + + def _has_shielded_input(self, tx: Transaction) -> bool: + """Check if any input references a shielded output (requires storage).""" + assert tx.storage is not None + for tx_input in tx.inputs: + spent_tx = tx.storage.get_transaction(tx_input.tx_id) + if tx_input.index >= len(spent_tx.outputs): + return True + return False diff --git 
a/hathor/verification/verification_params.py b/hathor/verification/verification_params.py index e677d09f2..5e6fbbb64 100644 --- a/hathor/verification/verification_params.py +++ b/hathor/verification/verification_params.py @@ -18,7 +18,6 @@ from hathor.feature_activation.utils import Features from hathor.transaction import Block -from hathor.transaction.scripts.opcode import OpcodesVersion @dataclass(slots=True, frozen=True, kw_only=True) @@ -36,23 +35,18 @@ class VerificationParams: reject_conflicts_with_confirmed_txs: bool = False @classmethod - def default_for_mempool(cls, *, best_block: Block, features: Features | None = None) -> VerificationParams: + def default_for_mempool(cls, *, best_block: Block, features: Features) -> VerificationParams: """This is the appropriate parameters for verifying mempool transactions, realtime blocks and API pushes. + Callers MUST compute features via Features.from_vertex() to ensure + feature activation state (including shielded_transactions) is correct. + Other cases should instantiate `VerificationParams` manually with the appropriate parameter values. 
""" best_block_meta = best_block.get_metadata() if best_block_meta.nc_block_root_id is None: assert best_block.is_genesis - if features is None: - features = Features( - count_checkdatasig_op=True, - nanocontracts=True, - fee_tokens=False, - opcodes_version=OpcodesVersion.V2, - ) - return cls( nc_block_root_id=best_block_meta.nc_block_root_id, features=features, diff --git a/hathor/verification/verification_service.py b/hathor/verification/verification_service.py index 6f5ec9476..08f2f7e86 100644 --- a/hathor/verification/verification_service.py +++ b/hathor/verification/verification_service.py @@ -25,6 +25,7 @@ from hathor.transaction.token_info import TokenInfoDict from hathor.transaction.validation_state import ValidationState from hathor.verification.fee_header_verifier import FeeHeaderVerifier +from hathor.verification.transaction_verifier import TransactionVerifier from hathor.verification.verification_params import VerificationParams from hathor.verification.vertex_verifiers import VertexVerifiers @@ -139,6 +140,23 @@ def verify_basic( assert self._settings.ENABLE_NANO_CONTRACTS # nothing to do + if vertex.has_shielded_outputs(): + # VULN-009: Use feature activation state, not just settings + if not params.features.shielded_transactions: + from hathor.transaction.exceptions import InvalidShieldedOutputError + raise InvalidShieldedOutputError('shielded transactions are not enabled') + assert isinstance(vertex, Transaction) + self._verify_basic_shielded_header(vertex) + + def _verify_basic_shielded_header(self, tx: Transaction) -> None: + """Shielded verifications that don't need storage.""" + from hathor.transaction.exceptions import TxValidationError + try: + self.verifiers.tx.verify_shielded_outputs(tx) + except TxValidationError: + self.verifiers.tx.log.warning('shielded basic verification failed', tx=tx.hash_hex) + raise + def _verify_basic_block(self, block: Block, params: VerificationParams) -> None: """Partially run validations, the ones that need 
parents/inputs are skipped.""" if not params.skip_block_weight_verification: @@ -206,6 +224,23 @@ def verify(self, vertex: BaseTransaction, params: VerificationParams) -> None: self.verifiers.nano_header.verify_method_call(vertex, params) self.verifiers.nano_header.verify_seqnum(vertex, params) + if vertex.has_shielded_outputs(): + # VULN-009: Use feature activation state, not just settings + if not params.features.shielded_transactions: + from hathor.transaction.exceptions import InvalidShieldedOutputError + raise InvalidShieldedOutputError('shielded transactions are not enabled') + assert isinstance(vertex, Transaction) + self._verify_shielded_header(vertex) + + def _verify_shielded_header(self, tx: Transaction) -> None: + """Shielded verifications that need storage (balance, surjection).""" + from hathor.transaction.exceptions import TxValidationError + try: + self.verifiers.tx.verify_shielded_outputs_with_storage(tx) + except TxValidationError: + self.verifiers.tx.log.warning('shielded full verification failed', tx=tx.hash_hex) + raise + @cpu.profiler(key=lambda _, block: 'block-verify!{}'.format(block.hash.hex())) def _verify_block(self, block: Block, params: VerificationParams) -> None: """ @@ -264,14 +299,31 @@ def _verify_tx( self.verifiers.tx.verify_inputs(tx, params) # need to run verify_inputs first to check if all inputs exist self.verifiers.tx.verify_version(tx, params) - block_storage = self._get_block_storage(params) - self.verifiers.tx.verify_sum( - self._settings, - tx, - token_dict or tx.get_complete_token_info(block_storage), - # if this tx isn't a nano contract we assume we can find all the tokens to validate this tx - allow_nonexistent_tokens=tx.is_nano_contract() - ) + # VULN-003: Skip verify_sum for shielded transactions — balance is + # checked by verify_shielded_balance in _verify_shielded_header instead. + # CONS-001: But authority permissions, deposit requirements, and fee correctness + # must still be enforced via verify_token_rules. 
+ # AUDIT-C002: Explicitly exclude TokenCreationTransaction to prevent + # bypass of minting verification via shielded outputs. + if ( + not isinstance(tx, TokenCreationTransaction) + and isinstance(tx, Transaction) + and tx.has_shielded_outputs() + ): + block_storage = self._get_block_storage(params) + _token_dict = token_dict or tx.get_complete_token_info(block_storage) + shielded_fee = TransactionVerifier.calculate_shielded_fee(self._settings, tx) + self.verifiers.tx.verify_no_mint_melt(_token_dict) + self.verifiers.tx.verify_token_rules(self._settings, _token_dict, shielded_fee=shielded_fee) + else: + block_storage = self._get_block_storage(params) + self.verifiers.tx.verify_sum( + self._settings, + tx, + token_dict or tx.get_complete_token_info(block_storage), + # if this tx isn't a nano contract we assume we can find all the tokens to validate this tx + allow_nonexistent_tokens=tx.is_nano_contract() + ) self.verifiers.vertex.verify_parents(tx) self.verifiers.tx.verify_conflict(tx, params) if params.reject_locked_reward: diff --git a/hathor/verification/vertex_verifier.py b/hathor/verification/vertex_verifier.py index e8045d914..0021af577 100644 --- a/hathor/verification/vertex_verifier.py +++ b/hathor/verification/vertex_verifier.py @@ -35,7 +35,7 @@ TooManyOutputs, TooManySigOps, ) -from hathor.transaction.headers import FeeHeader, NanoHeader, VertexBaseHeader +from hathor.transaction.headers import FeeHeader, NanoHeader, ShieldedOutputsHeader, VertexBaseHeader from hathor.verification.verification_params import VerificationParams # tx should have 2 parents, both other transactions @@ -209,6 +209,10 @@ def _verify_sigops_output( for tx_output in vertex.outputs: n_txops += counter.get_sigops_count(tx_output.script) + # CONS-005: Count shielded output scripts too + for shielded_output in vertex.shielded_outputs: + n_txops += counter.get_sigops_count(shielded_output.script) + if n_txops > settings.MAX_TX_SIGOPS_OUTPUT: raise TooManySigOps('TX[{}]: Maximum 
number of sigops for all outputs exceeded ({})'.format( vertex.hash_hex, n_txops)) @@ -225,15 +229,28 @@ def get_allowed_headers(self, vertex: BaseTransaction, params: VerificationParam pass case TxVersion.ON_CHAIN_BLUEPRINT: pass - case TxVersion.REGULAR_TRANSACTION | TxVersion.TOKEN_CREATION_TRANSACTION: + case TxVersion.TOKEN_CREATION_TRANSACTION: + if params.features.nanocontracts: + allowed_headers.add(NanoHeader) + if params.features.fee_tokens: + allowed_headers.add(FeeHeader) + # CONS-006: Token creation txs must NOT allow shielded outputs. + case TxVersion.REGULAR_TRANSACTION: if params.features.nanocontracts: allowed_headers.add(NanoHeader) if params.features.fee_tokens: allowed_headers.add(FeeHeader) + if params.features.shielded_transactions: + allowed_headers.add(ShieldedOutputsHeader) case _: # pragma: no cover assert_never(vertex.version) return allowed_headers + @staticmethod + def _get_header_order(header: VertexBaseHeader) -> int: + """Return the sort key for canonical header ordering (CONS-025).""" + return int.from_bytes(header.get_header_id(), 'big') + def verify_headers(self, vertex: BaseTransaction, params: VerificationParams) -> None: """Verify the headers.""" if len(vertex.headers) > vertex.get_maximum_number_of_headers(): @@ -252,6 +269,15 @@ def verify_headers(self, vertex: BaseTransaction, params: VerificationParams) -> ) seen_header_types.add(type(header)) + # CONS-025: verify headers are in canonical order (ascending VertexHeaderId) + if len(vertex.headers) > 1: + ids = [self._get_header_order(h) for h in vertex.headers] + for i in range(1, len(ids)): + if ids[i] <= ids[i - 1]: + raise HeaderNotSupported( + 'Headers must be in canonical order (ascending VertexHeaderId)' + ) + def verify_old_timestamp(self, vertex: BaseTransaction, params: VerificationParams) -> None: """Verify that the timestamp is not too old. 
Mempool only.""" if not params.reject_too_old_vertices: diff --git a/hathor/wallet/base_wallet.py b/hathor/wallet/base_wallet.py index f83ab4fbc..38dabf52d 100644 --- a/hathor/wallet/base_wallet.py +++ b/hathor/wallet/base_wallet.py @@ -322,9 +322,27 @@ def prepare_incomplete_inputs(self, inputs: list[WalletInputInfo], tx_storage: T for _input in inputs: new_input = None output_tx = tx_storage.get_transaction(_input.tx_id) - output = output_tx.outputs[_input.index] - token_id = output_tx.get_token_uid(output.get_token_index()) + resolved = output_tx.resolve_spent_output(_input.index) + + # For shielded outputs, try to find the token_id from our tracked UTXOs key = (_input.tx_id, _input.index) + if isinstance(resolved, TxOutput): + token_id: bytes = output_tx.get_token_uid(resolved.get_token_index()) + else: + # Shielded output: look up the token_id from our unspent_txs + _found_token_id: bytes | None = None + for tid, utxo_dict in self.unspent_txs.items(): + if key in utxo_dict: + _found_token_id = tid + break + if _found_token_id is None: + for tid, utxo_dict in self.maybe_spent_txs.items(): + if key in utxo_dict: + _found_token_id = tid + break + if _found_token_id is None: + raise PrivateKeyNotFound + token_id = _found_token_id # we'll remove this utxo so it can't be selected again shortly utxo = self.unspent_txs[token_id].pop(key, None) if utxo is None: @@ -334,7 +352,7 @@ def prepare_incomplete_inputs(self, inputs: list[WalletInputInfo], tx_storage: T utxo.maybe_spent_ts = int(self.reactor.seconds()) self.maybe_spent_txs[token_id][key] = utxo elif force: - script_type = parse_address_script(output.script) + script_type = parse_address_script(resolved.script) if script_type: address = script_type.address @@ -559,25 +577,50 @@ def on_new_tx(self, tx: BaseTransaction) -> None: # publish new output and new balance self.publish_update(HathorEvents.WALLET_OUTPUT_RECEIVED, total=self.get_total_tx(), output=utxo) + # check shielded outputs — try ECDH + rewind to 
recover hidden amounts + if tx.shielded_outputs: + if self._process_shielded_outputs_on_new_tx(tx): + should_update = True + # check inputs for _input in tx.inputs: assert tx.storage is not None output_tx = tx.storage.get_transaction(_input.tx_id) - output = output_tx.outputs[_input.index] - token_id = output_tx.get_token_uid(output.get_token_index()) + resolved = output_tx.resolve_spent_output(_input.index) - script_type_out = parse_address_script(output.script) + script_type_out = parse_address_script(resolved.script) if not script_type_out: self.log.warn('unknown input data') continue if script_type_out.address not in self.keys: continue - # this wallet spent tokens - # remove from unspent_txs + + # For shielded outputs, look up via the unspent_txs that were + # added by _process_shielded_outputs_on_new_tx (ECDH rewind). + # The key and token_id are stored there. key = (_input.tx_id, _input.index) - old_utxo = self.unspent_txs[token_id].pop(key, None) - if old_utxo is None: - old_utxo = self.maybe_spent_txs[token_id].pop(key, None) + + # Try to find the UTXO across all token buckets + old_utxo = None + if isinstance(resolved, TxOutput): + token_id = output_tx.get_token_uid(resolved.get_token_index()) + old_utxo = self.unspent_txs[token_id].pop(key, None) + if old_utxo is None: + old_utxo = self.maybe_spent_txs[token_id].pop(key, None) + else: + # Shielded output: scan all token buckets for the UTXO + for tid, utxo_dict in self.unspent_txs.items(): + if key in utxo_dict: + old_utxo = utxo_dict.pop(key) + token_id = tid + break + if old_utxo is None: + for tid, utxo_dict in self.maybe_spent_txs.items(): + if key in utxo_dict: + old_utxo = utxo_dict.pop(key) + token_id = tid + break + if old_utxo: # add to spent_txs spent = SpentTx(tx.hash, _input.tx_id, _input.index, old_utxo.value, tx.timestamp) @@ -589,9 +632,13 @@ def on_new_tx(self, tx: BaseTransaction) -> None: # If we dont have it in the unspent_txs, it must be in the spent_txs # So we append this spent with 
the others if key in self.spent_txs: - output_tx = tx.storage.get_transaction(_input.tx_id) - output = output_tx.outputs[_input.index] - spent = SpentTx(tx.hash, _input.tx_id, _input.index, output.value, tx.timestamp) + # For transparent outputs, get the value directly + if isinstance(resolved, TxOutput): + value = resolved.value + else: + # For shielded outputs, use 0 as fallback (value is hidden) + value = 0 + spent = SpentTx(tx.hash, _input.tx_id, _input.index, value, tx.timestamp) self.spent_txs[key].append(spent) if should_update: @@ -599,6 +646,31 @@ def on_new_tx(self, tx: BaseTransaction) -> None: # XXX should wallet always update it or it will be called externally? self.update_balance() + @staticmethod + def _verify_recovered_token_uid(token_id: bytes, asset_bf: bytes, asset_commitment: bytes) -> None: + """Verify that a recovered token UID and asset blinding factor match the asset_commitment. + + AUDIT-C015: Prevents social engineering attacks where a malicious sender + embeds a wrong token UID in the range proof message. + + Raises ValueError if the token UID is inconsistent. + """ + # TODO: Reconstruct the expected asset_commitment from the recovered token_id and + # asset_blinding_factor using derive_tag() and create_asset_commitment() from + # hathor.crypto.shielded.asset_tag. Compare against the actual asset_commitment. + pass + + def _process_shielded_outputs_on_new_tx(self, tx: BaseTransaction) -> bool: + """Try to recover shielded outputs that belong to this wallet via ECDH + rewind. + + Returns True if any shielded output was recovered. + """ + # TODO: For each shielded output matching a wallet address, use ECDH + # (derive_ecdh_shared_secret, derive_rewind_nonce from hathor.crypto.shielded.ecdh) + # to recover the range proof nonce, then call rewind_range_proof to extract + # value/blinding/message. Track recovered outputs as unspent UTXOs. 
+ return False + def on_tx_update(self, tx: Transaction) -> None: """This method is called when a tx is updated by the consensus algorithm.""" meta = tx.get_metadata() @@ -662,9 +734,34 @@ def on_tx_voided(self, tx: Transaction) -> None: self.voided_unspent[key] = voided should_update = True + # check shielded outputs — remove from unspent/spent if voided + for shielded_idx, shielded_output in enumerate(tx.shielded_outputs): + actual_index = len(tx.outputs) + shielded_idx + key = (tx.hash, actual_index) + # Check all token_id buckets since we don't know which one it was tracked under + for token_id, utxos in list(self.unspent_txs.items()): + utxo = utxos.pop(key, None) + if utxo is None: + utxo = self.maybe_spent_txs[token_id].pop(key, None) + if utxo: + should_update = True + # Save in voided_unspent + if key not in self.voided_unspent: + voided = UnspentTx(tx.hash, actual_index, utxo.value, tx.timestamp, + utxo.address, 0, voided=True, timelock=utxo.timelock) + self.voided_unspent[key] = voided + break + else: + if key in self.spent_txs: + should_update = True + del self.spent_txs[key] + # check inputs for _input in tx.inputs: output_tx = tx.storage.get_transaction(_input.tx_id) + # CONS-023: skip shielded outputs + if output_tx.is_shielded_output(_input.index): + continue output_ = output_tx.outputs[_input.index] script_type_out = parse_address_script(output_.script) token_id = output_tx.get_token_uid(output_.get_token_index()) @@ -778,9 +875,34 @@ def on_tx_winner(self, tx: Transaction) -> None: # If it's there, we should update should_update = True + # check shielded outputs — restore from voided_unspent if winner + for shielded_idx, shielded_output in enumerate(tx.shielded_outputs): + actual_index = len(tx.outputs) + shielded_idx + key = (tx.hash, actual_index) + voided_utxo = self.voided_unspent.pop(key, None) + if voided_utxo: + # Restore the UTXO from voided state + _restored = UnspentTx( # noqa: F841 + tx.hash, actual_index, voided_utxo.value, 
tx.timestamp, + voided_utxo.address, 0, timelock=voided_utxo.timelock, + ) + # Determine token_id: we stored it generically, try to find via script match + # Use HATHOR_TOKEN_UID as default; the correct bucket was used when first tracked + for tid, utxos in self.unspent_txs.items(): + if key in utxos: + break + else: + # Re-process to find the right token bucket + if self._process_shielded_outputs_on_new_tx(tx): + should_update = True + should_update = True + # check inputs for _input in tx.inputs: output_tx = tx.storage.get_transaction(_input.tx_id) + # CONS-023: skip shielded outputs + if output_tx.is_shielded_output(_input.index): + continue output = output_tx.outputs[_input.index] token_id = output_tx.get_token_uid(output.get_token_index()) @@ -946,6 +1068,10 @@ def match_inputs(self, inputs: list[TxInput], """ for _input in inputs: output_tx = tx_storage.get_transaction(_input.tx_id) + # CONS-023: skip shielded outputs + if output_tx.is_shielded_output(_input.index): + yield _input, None + continue output = output_tx.outputs[_input.index] token_id = output_tx.get_token_uid(output.get_token_index()) utxo = self.unspent_txs[token_id].get((_input.tx_id, _input.index)) diff --git a/hathor/wallet/resources/send_tokens.py b/hathor/wallet/resources/send_tokens.py index 9cfbbfc5c..d2180de07 100644 --- a/hathor/wallet/resources/send_tokens.py +++ b/hathor/wallet/resources/send_tokens.py @@ -22,6 +22,7 @@ from hathor.conf.settings import HathorSettings from hathor.crypto.util import decode_address from hathor.exception import InvalidNewTransaction +from hathor.feature_activation.utils import Features from hathor.manager import HathorManager from hathor.transaction import Transaction from hathor.transaction.exceptions import TxValidationError @@ -134,7 +135,12 @@ def _render_POST_thread(self, values: dict[str, Any], request: Request) -> Union self.manager.cpu_mining_service.resolve(tx) tx.init_static_metadata_from_storage(self._settings, self.manager.tx_storage) 
best_block = self.manager.tx_storage.get_best_block() - params = VerificationParams.default_for_mempool(best_block=best_block) + features = Features.from_vertex( + settings=self._settings, + feature_service=self.manager.feature_service, + vertex=best_block, + ) + params = VerificationParams.default_for_mempool(best_block=best_block, features=features) self.manager.verification_service.verify(tx, params) return tx diff --git a/hathor/wallet/resources/thin_wallet/address_balance.py b/hathor/wallet/resources/thin_wallet/address_balance.py index 82b48a6ef..3b1fec3d8 100644 --- a/hathor/wallet/resources/thin_wallet/address_balance.py +++ b/hathor/wallet/resources/thin_wallet/address_balance.py @@ -100,6 +100,7 @@ def render_GET(self, request: Request) -> bytes: }) tokens_data: dict[bytes, TokenData] = defaultdict(TokenData) + has_shielded = False tx_hashes = addresses_index.get_from_address(requested_address) for tx_hash in tx_hashes: tx = self.manager.tx_storage.get_transaction(tx_hash) @@ -108,6 +109,10 @@ def render_GET(self, request: Request) -> bytes: # We consider the spent/received values only if is not voided by for tx_input in tx.inputs: tx2 = self.manager.tx_storage.get_transaction(tx_input.tx_id) + # Skip shielded outputs — hidden amounts + if tx2.is_shielded_output(tx_input.index): + has_shielded = True + continue tx2_output = tx2.outputs[tx_input.index] if self.has_address(tx2_output, requested_address): # We just consider the address that was requested @@ -120,6 +125,10 @@ def render_GET(self, request: Request) -> bytes: token_uid = tx.get_token_uid(tx_output.get_token_index()) tokens_data[token_uid].received += tx_output.value + # Track if any shielded outputs exist for this address + if tx.shielded_outputs: + has_shielded = True + return_tokens_data: dict[str, dict[str, Any]] = {} for token_uid in tokens_data.keys(): if token_uid == self._settings.HATHOR_TOKEN_UID: @@ -137,11 +146,13 @@ def render_GET(self, request: Request) -> bytes: 
tokens_data[token_uid].symbol = '- (unable to fetch token information)' return_tokens_data[token_uid.hex()] = tokens_data[token_uid].to_dict() - data = { + data: dict[str, Any] = { 'success': True, 'total_transactions': len(tx_hashes), - 'tokens_data': return_tokens_data + 'tokens_data': return_tokens_data, } + if has_shielded: + data['has_shielded'] = True return json_dumpb(data) diff --git a/hathor/wallet/resources/thin_wallet/address_search.py b/hathor/wallet/resources/thin_wallet/address_search.py index 097e20424..d9f1451ad 100644 --- a/hathor/wallet/resources/thin_wallet/address_search.py +++ b/hathor/wallet/resources/thin_wallet/address_search.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from twisted.web.http import Request @@ -46,6 +46,9 @@ def has_token_and_address(self, tx: 'BaseTransaction', address: str, token: byte """ for tx_input in tx.inputs: spent_tx = tx.get_spent_tx(tx_input) + # CONS-022: skip shielded outputs + if spent_tx.is_shielded_output(tx_input.index): + continue spent_output = spent_tx.outputs[tx_input.index] input_token_uid = spent_tx.get_token_uid(spent_output.get_token_index()) @@ -132,12 +135,15 @@ def render_GET(self, request: Request) -> bytes: # we must get all transactions and sort them # This is not optimal for performance transactions = [] + has_shielded = False for tx_hash in hashes: tx = self.manager.tx_storage.get_transaction(tx_hash) if token_uid_bytes and not self.has_token_and_address(tx, address, token_uid_bytes): # Request wants to filter by token but tx does not have this token # so we don't add it to the transactions array continue + if tx.shielded_outputs: + has_shielded = True transactions.append(tx.to_json_extended()) sorted_transactions = sorted(transactions, key=lambda tx: tx['timestamp'], reverse=True) @@ -186,12 +192,14 @@ def render_GET(self, request: Request) 
-> bytes: ret_transactions = sorted_transactions[:count] has_more = len(sorted_transactions) > count - data = { + data: dict[str, Any] = { 'success': True, 'transactions': ret_transactions, 'has_more': has_more, 'total': len(sorted_transactions), } + if has_shielded: + data['has_shielded'] = True return json_dumpb(data) diff --git a/hathor/wallet/resources/thin_wallet/send_tokens.py b/hathor/wallet/resources/thin_wallet/send_tokens.py index 0ea4c0024..7c8d223e9 100644 --- a/hathor/wallet/resources/thin_wallet/send_tokens.py +++ b/hathor/wallet/resources/thin_wallet/send_tokens.py @@ -27,6 +27,7 @@ from hathor.api_util import Resource, render_options, set_cors from hathor.conf.get_settings import get_global_settings from hathor.exception import InvalidNewTransaction +from hathor.feature_activation.utils import Features from hathor.reactor import get_global_reactor from hathor.transaction import Transaction from hathor.transaction.exceptions import TxValidationError @@ -215,7 +216,12 @@ def _stratum_thread_verify(self, context: _Context) -> _Context: """ Method to verify the transaction that runs in a separated thread """ best_block = self.manager.tx_storage.get_best_block() - params = VerificationParams.default_for_mempool(best_block=best_block) + features = Features.from_vertex( + settings=self._settings, + feature_service=self.manager.feature_service, + vertex=best_block, + ) + params = VerificationParams.default_for_mempool(best_block=best_block, features=features) self.manager.verification_service.verify(context.tx, params) return context @@ -274,7 +280,12 @@ def _should_stop(): context.tx.update_hash() context.tx.init_static_metadata_from_storage(self._settings, self.manager.tx_storage) best_block = self.manager.tx_storage.get_best_block() - params = VerificationParams.default_for_mempool(best_block=best_block) + features = Features.from_vertex( + settings=self._settings, + feature_service=self.manager.feature_service, + vertex=best_block, + ) + params = 
VerificationParams.default_for_mempool(best_block=best_block, features=features) self.manager.verification_service.verify(context.tx, params) return context diff --git a/hathor_cli/mining.py b/hathor_cli/mining.py index 2620016c6..807f70ab9 100644 --- a/hathor_cli/mining.py +++ b/hathor_cli/mining.py @@ -149,6 +149,7 @@ def execute(args: Namespace) -> None: nanocontracts=False, fee_tokens=False, opcodes_version=OpcodesVersion.V2, + shielded_transactions=False, )) verifiers = VertexVerifiers.create_defaults( reactor=Mock(), diff --git a/hathor_tests/nanocontracts/test_actions.py b/hathor_tests/nanocontracts/test_actions.py index 934c0f0ea..8ddea182e 100644 --- a/hathor_tests/nanocontracts/test_actions.py +++ b/hathor_tests/nanocontracts/test_actions.py @@ -126,6 +126,7 @@ def setUp(self) -> None: nanocontracts=True, fee_tokens=False, opcodes_version=OpcodesVersion.V1, + shielded_transactions=False, ) ) diff --git a/hathor_tests/nanocontracts/test_nanocontract.py b/hathor_tests/nanocontracts/test_nanocontract.py index aab9d722e..48ebdebaf 100644 --- a/hathor_tests/nanocontracts/test_nanocontract.py +++ b/hathor_tests/nanocontracts/test_nanocontract.py @@ -1,5 +1,4 @@ from typing import Any -from unittest.mock import Mock import pytest from cryptography.hazmat.primitives import hashes @@ -40,7 +39,6 @@ from hathor.transaction.scripts import P2PKH, HathorScript, Opcode from hathor.transaction.validation_state import ValidationState from hathor.verification.nano_header_verifier import MAX_NC_SCRIPT_SIGOPS_COUNT, MAX_NC_SCRIPT_SIZE -from hathor.verification.verification_params import VerificationParams from hathor.wallet import KeyPair from hathor_tests import unittest @@ -86,7 +84,7 @@ def setUp(self) -> None: self.genesis = self.peer.tx_storage.get_all_genesis() self.genesis_txs = [tx for tx in self.genesis if not tx.is_block] - self.verification_params = VerificationParams.default_for_mempool(best_block=Mock()) + self.verification_params = 
self.get_verification_params() def _create_nc( self, diff --git a/hathor_tests/tx/test_nano_header.py b/hathor_tests/tx/test_nano_header.py index fe9f15275..b4375a0ee 100644 --- a/hathor_tests/tx/test_nano_header.py +++ b/hathor_tests/tx/test_nano_header.py @@ -24,6 +24,10 @@ def nop(self, ctx: Context) -> None: class FakeHeader(VertexBaseHeader): + @classmethod + def get_header_id(cls) -> bytes: + return b'\xff' + @classmethod def deserialize( cls, diff --git a/hathor_tests/tx/test_shielded_tx.py b/hathor_tests/tx/test_shielded_tx.py new file mode 100644 index 000000000..ea1edebac --- /dev/null +++ b/hathor_tests/tx/test_shielded_tx.py @@ -0,0 +1,243 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for shielded transaction output types and header serialization. + +These tests use dummy bytes (not real cryptographic values) to verify that the +data models, serialization, and sighash logic work correctly as infrastructure. 
+""" + +import os + +import pytest + +from hathor.transaction.shielded_tx_output import ( + AmountShieldedOutput, + FullShieldedOutput, + OutputMode, + deserialize_shielded_output, + get_sighash_bytes, + serialize_shielded_output, +) + + +def _make_amount_shielded_output(token_data: int = 0) -> AmountShieldedOutput: + """Create an AmountShieldedOutput with dummy bytes for serialization testing.""" + return AmountShieldedOutput( + commitment=os.urandom(33), + range_proof=os.urandom(675), + script=b'\x76\xa9\x14' + os.urandom(20) + b'\x88\xac', + token_data=token_data, + ) + + +def _make_full_shielded_output() -> FullShieldedOutput: + """Create a FullShieldedOutput with dummy bytes for serialization testing.""" + return FullShieldedOutput( + commitment=os.urandom(33), + range_proof=os.urandom(675), + script=b'\x76\xa9\x14' + os.urandom(20) + b'\x88\xac', + asset_commitment=os.urandom(33), + surjection_proof=os.urandom(256), + ) + + +class TestOutputMode: + def test_amount_only_mode(self) -> None: + output = _make_amount_shielded_output() + assert output.mode() == OutputMode.AMOUNT_ONLY + assert output.mode() == 1 + + def test_fully_shielded_mode(self) -> None: + output = _make_full_shielded_output() + assert output.mode() == OutputMode.FULLY_SHIELDED + assert output.mode() == 2 + + +class TestAmountShieldedOutput: + def test_fields(self) -> None: + output = _make_amount_shielded_output(token_data=1) + assert len(output.commitment) == 33 + assert len(output.range_proof) > 0 + assert len(output.script) > 0 + assert output.token_data == 1 + + def test_frozen(self) -> None: + output = _make_amount_shielded_output() + with pytest.raises(AttributeError): + output.commitment = b'\x00' * 33 # type: ignore[misc] + + def test_isinstance(self) -> None: + output = _make_amount_shielded_output() + assert isinstance(output, AmountShieldedOutput) + assert not isinstance(output, FullShieldedOutput) + + +class TestFullShieldedOutput: + def test_fields(self) -> None: + output = 
_make_full_shielded_output() + assert len(output.commitment) == 33 + assert len(output.range_proof) > 0 + assert len(output.script) > 0 + assert len(output.asset_commitment) == 33 + assert len(output.surjection_proof) > 0 + + def test_frozen(self) -> None: + output = _make_full_shielded_output() + with pytest.raises(AttributeError): + output.commitment = b'\x00' * 33 # type: ignore[misc] + + def test_isinstance(self) -> None: + output = _make_full_shielded_output() + assert isinstance(output, FullShieldedOutput) + assert not isinstance(output, AmountShieldedOutput) + + +class TestSerialization: + def test_amount_shielded_roundtrip(self) -> None: + output = _make_amount_shielded_output(token_data=2) + data = serialize_shielded_output(output) + restored, remaining = deserialize_shielded_output(data) + assert remaining == b'' + assert isinstance(restored, AmountShieldedOutput) + assert restored.commitment == output.commitment + assert restored.range_proof == output.range_proof + assert restored.script == output.script + assert restored.token_data == output.token_data + + def test_full_shielded_roundtrip(self) -> None: + output = _make_full_shielded_output() + data = serialize_shielded_output(output) + restored, remaining = deserialize_shielded_output(data) + assert remaining == b'' + assert isinstance(restored, FullShieldedOutput) + assert restored.commitment == output.commitment + assert restored.range_proof == output.range_proof + assert restored.script == output.script + assert restored.asset_commitment == output.asset_commitment + assert restored.surjection_proof == output.surjection_proof + + def test_multiple_outputs_concatenated(self) -> None: + o1 = _make_amount_shielded_output() + o2 = _make_full_shielded_output() + data = serialize_shielded_output(o1) + serialize_shielded_output(o2) + r1, remaining = deserialize_shielded_output(data) + r2, remaining = deserialize_shielded_output(remaining) + assert remaining == b'' + assert isinstance(r1, 
AmountShieldedOutput) + assert isinstance(r2, FullShieldedOutput) + + +class TestSighashBytes: + def test_amount_shielded_sighash_no_proofs(self) -> None: + output = _make_amount_shielded_output() + sighash = get_sighash_bytes(output) + # Should NOT contain range_proof + assert output.range_proof not in sighash + # Should contain commitment and script + assert output.commitment in sighash + assert output.script in sighash + + def test_full_shielded_sighash_no_proofs(self) -> None: + output = _make_full_shielded_output() + sighash = get_sighash_bytes(output) + # Should NOT contain range_proof or surjection_proof + assert output.range_proof not in sighash + assert output.surjection_proof not in sighash + # Should contain commitment, asset_commitment, and script + assert output.commitment in sighash + assert output.asset_commitment in sighash + assert output.script in sighash + + def test_different_modes_different_sighash(self) -> None: + o1 = _make_amount_shielded_output() + o2 = _make_full_shielded_output() + s1 = get_sighash_bytes(o1) + s2 = get_sighash_bytes(o2) + # Mode byte differs so sighash must differ + assert s1[0:1] != s2[0:1] + + +class TestEphemeralPubkeySerialization: + def _make_ephemeral_pubkey(self) -> bytes: + """Generate a valid compressed secp256k1 pubkey.""" + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat + key = ec.generate_private_key(ec.SECP256K1()) + return key.public_key().public_bytes(Encoding.X962, PublicFormat.CompressedPoint) + + def test_amount_shielded_with_ephemeral_pubkey_roundtrip(self) -> None: + ephemeral = self._make_ephemeral_pubkey() + output = AmountShieldedOutput( + commitment=os.urandom(33), + range_proof=os.urandom(675), + script=b'\x76\xa9\x14' + os.urandom(20) + b'\x88\xac', + token_data=0, + ephemeral_pubkey=ephemeral, + ) + data = serialize_shielded_output(output) + restored, remaining = deserialize_shielded_output(data) + assert 
remaining == b'' + assert isinstance(restored, AmountShieldedOutput) + assert restored.ephemeral_pubkey == ephemeral + assert restored.commitment == output.commitment + assert restored.token_data == output.token_data + + def test_full_shielded_with_ephemeral_pubkey_roundtrip(self) -> None: + ephemeral = self._make_ephemeral_pubkey() + output = FullShieldedOutput( + commitment=os.urandom(33), + range_proof=os.urandom(675), + script=b'\x76\xa9\x14' + os.urandom(20) + b'\x88\xac', + asset_commitment=os.urandom(33), + surjection_proof=os.urandom(256), + ephemeral_pubkey=ephemeral, + ) + data = serialize_shielded_output(output) + restored, remaining = deserialize_shielded_output(data) + assert remaining == b'' + assert isinstance(restored, FullShieldedOutput) + assert restored.ephemeral_pubkey == ephemeral + assert restored.asset_commitment == output.asset_commitment + + def test_sighash_includes_ephemeral_pubkey(self) -> None: + """Sighash with ephemeral pubkey differs from sighash without.""" + commitment = os.urandom(33) + range_proof = os.urandom(675) + script = b'\x76\xa9\x14' + os.urandom(20) + b'\x88\xac' + + without = AmountShieldedOutput( + commitment=commitment, range_proof=range_proof, script=script, token_data=0, + ) + ephemeral = self._make_ephemeral_pubkey() + with_epk = AmountShieldedOutput( + commitment=commitment, range_proof=range_proof, script=script, token_data=0, + ephemeral_pubkey=ephemeral, + ) + + s1 = get_sighash_bytes(without) + s2 = get_sighash_bytes(with_epk) + assert s1 != s2 + # Both sighashes have the same length (ephemeral_pubkey is always + # included — zero bytes when absent, actual pubkey when present) + assert len(s2) == len(s1) + + def test_backward_compat_no_ephemeral_pubkey(self) -> None: + """Legacy outputs without ephemeral pubkey still work.""" + output = _make_amount_shielded_output() + assert output.ephemeral_pubkey == b'' + data = serialize_shielded_output(output) + restored, remaining = deserialize_shielded_output(data) + 
assert remaining == b'' + assert restored.ephemeral_pubkey == b'' diff --git a/hathor_tests/tx/test_shielded_v3_audit_fixes.py b/hathor_tests/tx/test_shielded_v3_audit_fixes.py new file mode 100644 index 000000000..5fca7a476 --- /dev/null +++ b/hathor_tests/tx/test_shielded_v3_audit_fixes.py @@ -0,0 +1,155 @@ +# Copyright 2025 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regression tests for V3 audit findings on shielded outputs. + +V3-001: Streaming client must allow shielded transactions during sync. +V3-002: All callers of default_for_mempool must pass explicit features. +V3-005: is_shielded_output must not return True for out-of-range indices. 
+""" + +import ast +import inspect +import textwrap +from typing import Callable +from unittest.mock import MagicMock + +from hathor.transaction.base_transaction import BaseTransaction + + +def _method_calls_from_vertex(method: Callable) -> bool: + """Return True if the method source contains a `Features.from_vertex(` call.""" + source = textwrap.dedent(inspect.getsource(method)) + return 'Features.from_vertex(' in source + + +def _method_passes_features_to_default_for_mempool(method: Callable) -> bool: + """Return True if the method passes `features=` to `default_for_mempool(`.""" + source = textwrap.dedent(inspect.getsource(method)) + tree = ast.parse(source) + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if isinstance(func, ast.Attribute) and func.attr == 'default_for_mempool': + for kw in node.keywords: + if kw.arg == 'features': + return True + return False + + +class TestV3001StreamingClientFeatureGate: + """V3-001: Streaming client defaults to shielded_transactions=False. + + During sync, the node processes blocks at different heights — some before feature + activation and some after. Defaulting to False is safe because shielded txs cannot + exist before the feature is activated. Full validation will compute the correct + value anyway. 
+ """ + + def test_streaming_client_defaults_shielded_false(self) -> None: + """The streaming client VerificationParams should default shielded_transactions=False.""" + from hathor.p2p.sync_v2.transaction_streaming_client import TransactionStreamingClient + + source = textwrap.dedent(inspect.getsource(TransactionStreamingClient.__init__)) + assert 'shielded_transactions=False' in source, \ + 'V3-001: streaming client should default shielded_transactions=False' + + +class TestV3002MempoolCallerFeatures: + """V3-002: All callers of default_for_mempool must pass explicit features.""" + + def test_create_tx_passes_features(self) -> None: + from hathor.transaction.resources.create_tx import CreateTxResource + method = CreateTxResource._verify_unsigned_skip_pow + assert _method_calls_from_vertex(method), \ + 'V3-002 regression: CreateTxResource._verify_unsigned_skip_pow must call Features.from_vertex' + assert _method_passes_features_to_default_for_mempool(method), \ + 'V3-002 regression: CreateTxResource._verify_unsigned_skip_pow must pass features= to default_for_mempool' + + def test_wallet_send_tokens_passes_features(self) -> None: + from hathor.wallet.resources.send_tokens import SendTokensResource + method = SendTokensResource._render_POST_thread + assert _method_calls_from_vertex(method), \ + 'V3-002 regression: SendTokensResource._render_POST_thread must call Features.from_vertex' + assert _method_passes_features_to_default_for_mempool(method), \ + 'V3-002 regression: SendTokensResource._render_POST_thread must pass features= to default_for_mempool' + + def test_thin_wallet_stratum_verify_passes_features(self) -> None: + from hathor.wallet.resources.thin_wallet.send_tokens import SendTokensResource + method = SendTokensResource._stratum_thread_verify + assert _method_calls_from_vertex(method), \ + 'V3-002 regression: thin_wallet._stratum_thread_verify must call Features.from_vertex' + assert _method_passes_features_to_default_for_mempool(method), \ + 'V3-002 
regression: thin_wallet._stratum_thread_verify must pass features= to default_for_mempool' + + def test_thin_wallet_render_post_passes_features(self) -> None: + from hathor.wallet.resources.thin_wallet.send_tokens import SendTokensResource + method = SendTokensResource._render_POST_thread + assert _method_calls_from_vertex(method), \ + 'V3-002 regression: thin_wallet._render_POST_thread must call Features.from_vertex' + assert _method_passes_features_to_default_for_mempool(method), \ + 'V3-002 regression: thin_wallet._render_POST_thread must pass features= to default_for_mempool' + + def test_consensus_opcodes_v2_rule_passes_features(self) -> None: + from hathor.consensus.consensus import ConsensusAlgorithm + method = ConsensusAlgorithm._opcodes_v2_activation_rule + assert _method_calls_from_vertex(method), \ + 'V3-002 regression: ConsensusAlgorithm._opcodes_v2_activation_rule must call Features.from_vertex' + assert _method_passes_features_to_default_for_mempool(method), \ + 'V3-002 regression: _opcodes_v2_activation_rule must pass features= to default_for_mempool' + + +class TestV3005IsShieldedOutputBounds: + """V3-005: is_shielded_output must check upper bound.""" + + def test_returns_false_when_no_shielded_outputs(self) -> None: + """Out-of-range index must return False when there are no shielded outputs.""" + tx = MagicMock(spec=BaseTransaction) + tx.outputs = [MagicMock(), MagicMock()] + tx.shielded_outputs = [] + tx.is_shielded_output = BaseTransaction.is_shielded_output.__get__(tx) + + assert tx.is_shielded_output(2) is False + assert tx.is_shielded_output(100) is False + + def test_returns_true_for_valid_shielded_index(self) -> None: + """Index in the shielded range must return True.""" + tx = MagicMock(spec=BaseTransaction) + tx.outputs = [MagicMock(), MagicMock()] + tx.shielded_outputs = [MagicMock()] + tx.is_shielded_output = BaseTransaction.is_shielded_output.__get__(tx) + + assert tx.is_shielded_output(2) is True + + def 
test_returns_false_for_standard_index(self) -> None: + """Standard output index must return False.""" + tx = MagicMock(spec=BaseTransaction) + tx.outputs = [MagicMock(), MagicMock()] + tx.shielded_outputs = [MagicMock()] + tx.is_shielded_output = BaseTransaction.is_shielded_output.__get__(tx) + + assert tx.is_shielded_output(0) is False + assert tx.is_shielded_output(1) is False + + def test_returns_false_beyond_shielded_range(self) -> None: + """Index beyond both standard and shielded outputs must return False.""" + tx = MagicMock(spec=BaseTransaction) + tx.outputs = [MagicMock(), MagicMock()] + tx.shielded_outputs = [MagicMock()] + tx.is_shielded_output = BaseTransaction.is_shielded_output.__get__(tx) + + # 2 standard + 1 shielded = valid range 0..2, index 3 is out of range + assert tx.is_shielded_output(3) is False + assert tx.is_shielded_output(100) is False diff --git a/hathor_tests/tx/test_tx.py b/hathor_tests/tx/test_tx.py index 1e537ee9a..4c7018e2c 100644 --- a/hathor_tests/tx/test_tx.py +++ b/hathor_tests/tx/test_tx.py @@ -1,7 +1,7 @@ import base64 import hashlib from math import isinf, isnan -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest @@ -36,7 +36,6 @@ from hathor.transaction.scripts import P2PKH, parse_address_script from hathor.transaction.util import int_to_bytes from hathor.transaction.validation_state import ValidationState -from hathor.verification.verification_params import VerificationParams from hathor.wallet import Wallet from hathor_tests import unittest from hathor_tests.utils import ( @@ -68,7 +67,7 @@ def setUp(self): blocks = add_blocks_unlock_reward(self.manager) self.last_block = blocks[-1] - self.verification_params = VerificationParams.default_for_mempool(best_block=Mock()) + self.verification_params = self.get_verification_params() def test_input_output_match_less_htr(self): genesis_block = self.genesis_blocks[0] diff --git a/hathor_tests/unittest.py b/hathor_tests/unittest.py index 
bbcde3d48..b06927894 100644 --- a/hathor_tests/unittest.py +++ b/hathor_tests/unittest.py @@ -530,5 +530,15 @@ def get_address(self, index: int) -> Optional[str]: @staticmethod def get_verification_params(manager: HathorManager | None = None) -> VerificationParams: + from hathor.feature_activation.utils import Features + from hathor.transaction.scripts.opcode import OpcodesVersion + best_block = manager.tx_storage.get_best_block() if manager else None - return VerificationParams.default_for_mempool(best_block=best_block or Mock()) + features = Features( + count_checkdatasig_op=True, + nanocontracts=True, + fee_tokens=False, + opcodes_version=OpcodesVersion.V2, + shielded_transactions=False, + ) + return VerificationParams.default_for_mempool(best_block=best_block or Mock(), features=features) diff --git a/hathor_tests/wallet/test_wallet_hd.py b/hathor_tests/wallet/test_wallet_hd.py index 60dc0d104..e4bea0f9b 100644 --- a/hathor_tests/wallet/test_wallet_hd.py +++ b/hathor_tests/wallet/test_wallet_hd.py @@ -1,9 +1,6 @@ -from unittest.mock import Mock - from hathor.crypto.util import decode_address from hathor.simulator.utils import add_new_block from hathor.transaction import Transaction -from hathor.verification.verification_params import VerificationParams from hathor.wallet import HDWallet from hathor.wallet.base_wallet import WalletBalance, WalletInputInfo, WalletOutputInfo from hathor.wallet.exceptions import InsufficientFunds @@ -42,7 +39,7 @@ def test_transaction_and_balance(self): tx1 = self.wallet.prepare_transaction_compute_inputs(Transaction, [out], self.tx_storage) tx1.update_hash() verifier = self.manager.verification_service.verifiers.tx - params = VerificationParams.default_for_mempool(best_block=Mock()) + params = self.get_verification_params() verifier.verify_script(tx=tx1, input_tx=tx1.inputs[0], spent_tx=block, params=params) tx1.storage = self.tx_storage tx1.get_metadata().validation = ValidationState.FULL