From c8791241e190b884e1ab008ede0d6455f2c708b2 Mon Sep 17 00:00:00 2001 From: Daniel Wang <99078276+dantaik@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:25:43 +0800 Subject: [PATCH] fix(protocol): fix bridge bug caused by incorrect check of `receivedAt` (by OZ) (#16545) --- packages/protocol/contracts/bridge/Bridge.sol | 50 ++++++++++--- .../protocol/contracts/bridge/IBridge.sol | 3 +- packages/protocol/test/bridge/Bridge.t.sol | 75 ++++++++++++++----- 3 files changed, 98 insertions(+), 30 deletions(-) diff --git a/packages/protocol/contracts/bridge/Bridge.sol b/packages/protocol/contracts/bridge/Bridge.sol index da7b4a6969..f18576ac08 100644 --- a/packages/protocol/contracts/bridge/Bridge.sol +++ b/packages/protocol/contracts/bridge/Bridge.sol @@ -53,7 +53,10 @@ contract Bridge is EssentialContract, IBridge { error B_INVALID_STATUS(); error B_INVALID_USER(); error B_INVALID_VALUE(); + error B_MESSAGE_NOT_PROVEN(); error B_MESSAGE_NOT_SENT(); + error B_MESSAGE_NOT_SUSPENDED(); + error B_MESSAGE_SUSPENDED(); error B_NON_RETRIABLE(); error B_NOT_FAILED(); error B_NOT_RECEIVED(); @@ -86,11 +89,26 @@ contract Bridge is EssentialContract, IBridge { external onlyFromOwnerOrNamed("bridge_watchdog") { - uint64 _timestamp = _suspend ? type(uint64).max : uint64(block.timestamp); for (uint256 i; i < _msgHashes.length; ++i) { bytes32 msgHash = _msgHashes[i]; - proofReceipt[msgHash].receivedAt = _timestamp; - emit MessageSuspended(msgHash, _suspend); + + if (_suspend) { + if (proofReceipt[msgHash].receivedAt == 0) revert B_MESSAGE_NOT_PROVEN(); + if (proofReceipt[msgHash].receivedAt == type(uint64).max) { + revert B_MESSAGE_SUSPENDED(); + } + + proofReceipt[msgHash].receivedAt = type(uint64).max; + emit MessageSuspended(msgHash, true, 0); + } else { + // Note before we set the receivedAt to current timestamp, we have to be really + // careful that this message must have been proven then suspended. + if (proofReceipt[msgHash].receivedAt != type(uint64).max) { + revert B_MESSAGE_NOT_SUSPENDED(); + } + proofReceipt[msgHash].receivedAt = uint64(block.timestamp); + emit MessageSuspended(msgHash, false, uint64(block.timestamp)); + } } } @@ -166,9 +184,12 @@ contract Bridge is EssentialContract, IBridge { if (messageStatus[msgHash] != Status.NEW) revert B_STATUS_MISMATCH(); uint64 receivedAt = proofReceipt[msgHash].receivedAt; - bool isMessageProven = receivedAt != 0; + if (receivedAt == type(uint64).max) revert B_MESSAGE_SUSPENDED(); + + (uint256 invocationDelay,) = getInvocationDelays(); - if (!isMessageProven) { + bool isNewlyProven; + if (receivedAt == 0) { address signalService = resolve("signal_service", false); if (!ISignalService(signalService).isSignalSent(address(this), msgHash)) { @@ -181,10 +202,12 @@ contract Bridge is EssentialContract, IBridge { } receivedAt = uint64(block.timestamp); - proofReceipt[msgHash].receivedAt = receivedAt; - } + isNewlyProven = true; - (uint256 invocationDelay,) = getInvocationDelays(); + if (invocationDelay != 0) { + proofReceipt[msgHash].receivedAt = receivedAt; + } + } if (block.timestamp >= invocationDelay + receivedAt) { delete proofReceipt[msgHash]; @@ -206,7 +229,7 @@ contract Bridge is EssentialContract, IBridge { _message.srcOwner.sendEtherAndVerify(_message.value); } emit MessageRecalled(msgHash); - } else if (!isMessageProven) { + } else if (isNewlyProven) { emit MessageReceived(msgHash, _message, true); } else { revert B_INVOCATION_TOO_EARLY(); @@ -227,17 +250,20 @@ contract Bridge is EssentialContract, IBridge { if (messageStatus[msgHash] != Status.NEW) revert B_STATUS_MISMATCH(); address signalService = resolve("signal_service", false); + uint64 receivedAt = proofReceipt[msgHash].receivedAt; - bool isMessageProven = receivedAt != 0; + if (receivedAt == type(uint64).max) revert B_MESSAGE_SUSPENDED(); (uint256 invocationDelay, uint256 invocationExtraDelay) = getInvocationDelays(); - if (!isMessageProven) { + bool isNewlyProven; + if (receivedAt == 0) { if (!_proveSignalReceived(signalService, msgHash, _message.srcChainId, _proof)) { revert B_NOT_RECEIVED(); } receivedAt = uint64(block.timestamp); + isNewlyProven = true; if (invocationDelay != 0) { proofReceipt[msgHash] = ProofReceipt({ @@ -299,7 +325,7 @@ contract Bridge is EssentialContract, IBridge { refundTo.sendEtherAndVerify(refundAmount); } emit MessageExecuted(msgHash); - } else if (!isMessageProven) { + } else if (isNewlyProven) { emit MessageReceived(msgHash, _message, false); } else { revert B_INVOCATION_TOO_EARLY(); diff --git a/packages/protocol/contracts/bridge/IBridge.sol b/packages/protocol/contracts/bridge/IBridge.sol index f1f9cf83bb..dd2c22ce92 100644 --- a/packages/protocol/contracts/bridge/IBridge.sol +++ b/packages/protocol/contracts/bridge/IBridge.sol @@ -95,7 +95,8 @@ interface IBridge { /// @notice Emitted when a message is suspended or unsuspended. /// @param msgHash The hash of the message. /// @param suspended True if the message is suspended. - event MessageSuspended(bytes32 msgHash, bool suspended); + /// @param receivedAt The received-at timestamp, 0 if suspended is true. + event MessageSuspended(bytes32 msgHash, bool suspended, uint64 receivedAt); /// @notice Emitted when an address is banned or unbanned. /// @param addr The address to ban or unban. diff --git a/packages/protocol/test/bridge/Bridge.t.sol b/packages/protocol/test/bridge/Bridge.t.sol index ce5a9b485f..5dac38d141 100644 --- a/packages/protocol/test/bridge/Bridge.t.sol +++ b/packages/protocol/test/bridge/Bridge.t.sol @@ -737,33 +737,74 @@ contract BridgeTest is TaikoTest { } function test_Bridge_suspend_messages() public { - vm.startPrank(Alice); - (IBridge.Message memory message, bytes memory proof) = - setUpPredefinedSuccessfulProcessMessageCall(); + IBridge.Message memory message = IBridge.Message({ + id: 0, + from: address(bridge), + srcChainId: uint64(block.chainid), + destChainId: destChainId, + srcOwner: Alice, + destOwner: Alice, + to: Alice, + refundTo: Alice, + value: 1000, + fee: 1000, + gasLimit: 1_000_000, + data: "", + memo: "" + }); + // Mocking proof - but obviously it needs to be created in prod + // corresponding to the message + bytes memory proof = hex"00"; - bytes32 msgHash = destChainBridge.hashMessage(message); + vm.chainId(destChainId); + // This in is the first transaction setting the proofReceipt + + bytes32 msgHash = dest2StepBridge.hashMessage(message); bytes32[] memory messageHashes = new bytes32[](1); messageHashes[0] = msgHash; - vm.stopPrank(); + // Unsuspend a msg that has not been suspended will revert + vm.prank(dest2StepBridge.owner()); + vm.expectRevert(Bridge.B_MESSAGE_NOT_SUSPENDED.selector); + dest2StepBridge.suspendMessages(messageHashes, false); + + // Suspend that will revert + vm.prank(dest2StepBridge.owner()); + vm.expectRevert(Bridge.B_MESSAGE_NOT_PROVEN.selector); + dest2StepBridge.suspendMessages(messageHashes, true); + + vm.prank(Bob); + dest2StepBridge.processMessage(message, proof); + // Suspend - vm.prank(destChainBridge.owner(), destChainBridge.owner()); - destChainBridge.suspendMessages(messageHashes, true); + vm.prank(dest2StepBridge.owner()); + dest2StepBridge.suspendMessages(messageHashes, true); - vm.startPrank(Alice); - vm.expectRevert(Bridge.B_INVOCATION_TOO_EARLY.selector); - destChainBridge.processMessage(message, proof); + // Suspend again will revert + vm.prank(dest2StepBridge.owner()); + vm.expectRevert(Bridge.B_MESSAGE_SUSPENDED.selector); + dest2StepBridge.suspendMessages(messageHashes, true); + + // Try to process the message + vm.prank(Alice); + vm.expectRevert(Bridge.B_MESSAGE_SUSPENDED.selector); + dest2StepBridge.processMessage(message, proof); - vm.stopPrank(); // Unsuspend - vm.prank(destChainBridge.owner(), destChainBridge.owner()); - destChainBridge.suspendMessages(messageHashes, false); + vm.prank(dest2StepBridge.owner()); + dest2StepBridge.suspendMessages(messageHashes, false); - vm.startPrank(Alice); - destChainBridge.processMessage(message, proof); + vm.prank(Alice); + vm.expectRevert(Bridge.B_INVOCATION_TOO_EARLY.selector); + dest2StepBridge.processMessage(message, proof); - IBridge.Status status = destChainBridge.messageStatus(msgHash); + // Go in the future and try again + vm.warp(block.timestamp + 30 days); + vm.prank(Alice); + dest2StepBridge.processMessage(message, proof); + + IBridge.Status status = dest2StepBridge.messageStatus(msgHash); assertEq(status == IBridge.Status.DONE, true); } @@ -778,7 +819,7 @@ contract BridgeTest is TaikoTest { vm.stopPrank(); // Ban address - vm.prank(destChainBridge.owner(), destChainBridge.owner()); + vm.prank(destChainBridge.owner()); destChainBridge.banAddress(message.to, true); vm.startPrank(Alice);