Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(protocol): fix bridge bug caused by incorrect check of receivedAt (by OZ) #16545

Merged
merged 20 commits into from
Mar 28, 2024
Merged
50 changes: 38 additions & 12 deletions packages/protocol/contracts/bridge/Bridge.sol
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ contract Bridge is EssentialContract, IBridge {
error B_INVALID_STATUS();
error B_INVALID_USER();
error B_INVALID_VALUE();
error B_MESSAGE_NOT_PROVEN();
error B_MESSAGE_NOT_SENT();
error B_MESSAGE_NOT_SUSPENDED();
error B_MESSAGE_SUSPENDED();
error B_NON_RETRIABLE();
error B_NOT_FAILED();
error B_NOT_RECEIVED();
Expand Down Expand Up @@ -86,11 +89,26 @@ contract Bridge is EssentialContract, IBridge {
external
onlyFromOwnerOrNamed("bridge_watchdog")
{
uint64 _timestamp = _suspend ? type(uint64).max : uint64(block.timestamp);
for (uint256 i; i < _msgHashes.length; ++i) {
bytes32 msgHash = _msgHashes[i];
proofReceipt[msgHash].receivedAt = _timestamp;
emit MessageSuspended(msgHash, _suspend);

if (_suspend) {
if (proofReceipt[msgHash].receivedAt == 0) revert B_MESSAGE_NOT_PROVEN();
if (proofReceipt[msgHash].receivedAt == type(uint64).max) {
revert B_MESSAGE_SUSPENDED();
}

proofReceipt[msgHash].receivedAt = type(uint64).max;
emit MessageSuspended(msgHash, true, 0);
} else {
// Note before we set the receivedAt to current timestamp, we have to be really
// careful that this message must have been proven then suspended.
if (proofReceipt[msgHash].receivedAt != type(uint64).max) {
revert B_MESSAGE_NOT_SUSPENDED();
}
proofReceipt[msgHash].receivedAt = uint64(block.timestamp);
emit MessageSuspended(msgHash, false, uint64(block.timestamp));
}
}
}

Expand Down Expand Up @@ -166,9 +184,12 @@ contract Bridge is EssentialContract, IBridge {
if (messageStatus[msgHash] != Status.NEW) revert B_STATUS_MISMATCH();

uint64 receivedAt = proofReceipt[msgHash].receivedAt;
bool isMessageProven = receivedAt != 0;
if (receivedAt == type(uint64).max) revert B_MESSAGE_SUSPENDED();

(uint256 invocationDelay,) = getInvocationDelays();

if (!isMessageProven) {
bool isNewlyProven;
Dismissed Show dismissed Hide dismissed
if (receivedAt == 0) {
address signalService = resolve("signal_service", false);

if (!ISignalService(signalService).isSignalSent(address(this), msgHash)) {
Expand All @@ -181,10 +202,12 @@ contract Bridge is EssentialContract, IBridge {
}

receivedAt = uint64(block.timestamp);
proofReceipt[msgHash].receivedAt = receivedAt;
}
isNewlyProven = true;

(uint256 invocationDelay,) = getInvocationDelays();
if (invocationDelay != 0) {
proofReceipt[msgHash].receivedAt = receivedAt;
}
}

if (block.timestamp >= invocationDelay + receivedAt) {
delete proofReceipt[msgHash];
Expand All @@ -206,7 +229,7 @@ contract Bridge is EssentialContract, IBridge {
_message.srcOwner.sendEtherAndVerify(_message.value);
}
emit MessageRecalled(msgHash);
} else if (!isMessageProven) {
} else if (isNewlyProven) {
emit MessageReceived(msgHash, _message, true);
} else {
revert B_INVOCATION_TOO_EARLY();
Expand All @@ -227,17 +250,20 @@ contract Bridge is EssentialContract, IBridge {
if (messageStatus[msgHash] != Status.NEW) revert B_STATUS_MISMATCH();

address signalService = resolve("signal_service", false);

uint64 receivedAt = proofReceipt[msgHash].receivedAt;
bool isMessageProven = receivedAt != 0;
if (receivedAt == type(uint64).max) revert B_MESSAGE_SUSPENDED();

(uint256 invocationDelay, uint256 invocationExtraDelay) = getInvocationDelays();

if (!isMessageProven) {
bool isNewlyProven;
Dismissed Show dismissed Hide dismissed
if (receivedAt == 0) {
if (!_proveSignalReceived(signalService, msgHash, _message.srcChainId, _proof)) {
revert B_NOT_RECEIVED();
}

receivedAt = uint64(block.timestamp);
isNewlyProven = true;

if (invocationDelay != 0) {
proofReceipt[msgHash] = ProofReceipt({
Expand Down Expand Up @@ -299,7 +325,7 @@ contract Bridge is EssentialContract, IBridge {
refundTo.sendEtherAndVerify(refundAmount);
}
emit MessageExecuted(msgHash);
} else if (!isMessageProven) {
} else if (isNewlyProven) {
emit MessageReceived(msgHash, _message, false);
} else {
revert B_INVOCATION_TOO_EARLY();
Expand Down
3 changes: 2 additions & 1 deletion packages/protocol/contracts/bridge/IBridge.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ interface IBridge {
/// @notice Emitted when a message is suspended or unsuspended.
/// @param msgHash The hash of the message.
/// @param suspended True if the message is suspended.
event MessageSuspended(bytes32 msgHash, bool suspended);
/// @param receivedAt The received-at timestamp, 0 if suspended is true.
event MessageSuspended(bytes32 msgHash, bool suspended, uint64 receivedAt);

/// @notice Emitted when an address is banned or unbanned.
/// @param addr The address to ban or unban.
Expand Down
75 changes: 58 additions & 17 deletions packages/protocol/test/bridge/Bridge.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -737,33 +737,74 @@ contract BridgeTest is TaikoTest {
}

function test_Bridge_suspend_messages() public {
vm.startPrank(Alice);
(IBridge.Message memory message, bytes memory proof) =
setUpPredefinedSuccessfulProcessMessageCall();
IBridge.Message memory message = IBridge.Message({
id: 0,
from: address(bridge),
srcChainId: uint64(block.chainid),
destChainId: destChainId,
srcOwner: Alice,
destOwner: Alice,
to: Alice,
refundTo: Alice,
value: 1000,
fee: 1000,
gasLimit: 1_000_000,
data: "",
memo: ""
});
// Mocking proof - but obviously it needs to be created in prod
// corresponding to the message
bytes memory proof = hex"00";

bytes32 msgHash = destChainBridge.hashMessage(message);
vm.chainId(destChainId);
// This in is the first transaction setting the proofReceipt

bytes32 msgHash = dest2StepBridge.hashMessage(message);
bytes32[] memory messageHashes = new bytes32[](1);
messageHashes[0] = msgHash;

vm.stopPrank();
// Unsuspend a msg that has not been suspended will revert
vm.prank(dest2StepBridge.owner());
vm.expectRevert(Bridge.B_MESSAGE_NOT_SUSPENDED.selector);
dest2StepBridge.suspendMessages(messageHashes, false);

// Suspend that will revert
vm.prank(dest2StepBridge.owner());
vm.expectRevert(Bridge.B_MESSAGE_NOT_PROVEN.selector);
dest2StepBridge.suspendMessages(messageHashes, true);

vm.prank(Bob);
dest2StepBridge.processMessage(message, proof);

// Suspend
vm.prank(destChainBridge.owner(), destChainBridge.owner());
destChainBridge.suspendMessages(messageHashes, true);
vm.prank(dest2StepBridge.owner());
dest2StepBridge.suspendMessages(messageHashes, true);

vm.startPrank(Alice);
vm.expectRevert(Bridge.B_INVOCATION_TOO_EARLY.selector);
destChainBridge.processMessage(message, proof);
// Suspend again will revert
vm.prank(dest2StepBridge.owner());
vm.expectRevert(Bridge.B_MESSAGE_SUSPENDED.selector);
dest2StepBridge.suspendMessages(messageHashes, true);

// Try to process the message
vm.prank(Alice);
vm.expectRevert(Bridge.B_MESSAGE_SUSPENDED.selector);
dest2StepBridge.processMessage(message, proof);

vm.stopPrank();
// Unsuspend
vm.prank(destChainBridge.owner(), destChainBridge.owner());
destChainBridge.suspendMessages(messageHashes, false);
vm.prank(dest2StepBridge.owner());
dest2StepBridge.suspendMessages(messageHashes, false);

vm.startPrank(Alice);
destChainBridge.processMessage(message, proof);
vm.prank(Alice);
vm.expectRevert(Bridge.B_INVOCATION_TOO_EARLY.selector);
dest2StepBridge.processMessage(message, proof);

IBridge.Status status = destChainBridge.messageStatus(msgHash);
// Go in the future and try again
vm.warp(block.timestamp + 30 days);

vm.prank(Alice);
dest2StepBridge.processMessage(message, proof);

IBridge.Status status = dest2StepBridge.messageStatus(msgHash);
assertEq(status == IBridge.Status.DONE, true);
}

Expand All @@ -778,7 +819,7 @@ contract BridgeTest is TaikoTest {

vm.stopPrank();
// Ban address
vm.prank(destChainBridge.owner(), destChainBridge.owner());
vm.prank(destChainBridge.owner());
destChainBridge.banAddress(message.to, true);

vm.startPrank(Alice);
Expand Down
Loading