diff --git a/contracts/SplitterContract.sol b/contracts/ERC20Splitter.sol similarity index 66% rename from contracts/SplitterContract.sol rename to contracts/ERC20Splitter.sol index 55d6503..f021ad9 100644 --- a/contracts/SplitterContract.sol +++ b/contracts/ERC20Splitter.sol @@ -1,13 +1,16 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.0; +// SPDX-License-Identifier: CC0-1.0 +pragma solidity 0.8.9; import '@openzeppelin/contracts/token/ERC20/IERC20.sol'; import '@openzeppelin/contracts/security/ReentrancyGuard.sol'; contract ERC20Splitter is ReentrancyGuard { + // tokenAddress => userAddress => balance mapping(address => mapping(address => uint256)) public balances; - mapping(address => address[]) private _userTokens; - mapping(address => mapping(address => bool)) private _hasToken; + // userAddress => tokenAddress[] + mapping(address => address[]) private userTokens; + // tokenAddress => boolean + mapping(address => mapping(address => bool)) private hasToken; /** Events **/ @@ -15,12 +18,12 @@ contract ERC20Splitter is ReentrancyGuard { address indexed depositor, address[] tokenAddresses, uint256[] amounts, - uint256[][] shares, + uint16[][] shares, address[][] recipients ); - event Withdraw(address indexed user, address[] tokenAddresses, uint256[] amounts); + event Withdraw(address indexed user, uint256[] amounts); - uint256 public constant MAX_SHARES = 10000; + uint16 public constant MAX_SHARES = 10000; /** External Functions **/ @@ -32,7 +35,7 @@ contract ERC20Splitter is ReentrancyGuard { function deposit( address[] calldata tokenAddresses, uint256[] calldata amounts, - uint256[][] calldata shares, + uint16[][] calldata shares, address[][] calldata recipients ) external payable nonReentrant { require(tokenAddresses.length == amounts.length, 'ERC20Splitter: Invalid input lengths'); @@ -58,41 +61,40 @@ contract ERC20Splitter is ReentrancyGuard { /// @notice Withdraw all tokens that the caller is entitled to. /// Tokens are automatically determined based on previous deposits. function withdraw() external nonReentrant { - address[] storage userTokens = _userTokens[msg.sender]; - require(userTokens.length > 0, 'ERC20Splitter: No tokens to withdraw'); + address payable to = payable(msg.sender); + address[] storage senderTokens = userTokens[to]; - address[] memory withdrawnTokens = new address[](userTokens.length); - uint256[] memory withdrawnAmounts = new uint256[](userTokens.length); - - for (uint256 i = 0; i < userTokens.length; i++) { - address tokenAddress = userTokens[i]; - uint256 amount = balances[tokenAddress][msg.sender]; + if (senderTokens.length == 0) { + return; + } - if (amount > 0) { - balances[tokenAddress][msg.sender] = 0; + uint256[] memory withdrawnAmounts = new uint256[](senderTokens.length); - if (tokenAddress == address(0)) { - (bool success, ) = msg.sender.call{ value: amount }(''); - require(success, 'ERC20Splitter: Failed to send Ether'); - } else { - require(tokenAddress != address(0), 'ERC20Splitter: Invalid token address'); + for (uint256 i = 0; i < senderTokens.length; i++) { + address tokenAddress = senderTokens[i]; + uint256 amount = balances[tokenAddress][to]; - require( - IERC20(tokenAddress).transferFrom(address(this), msg.sender, amount), - 'ERC20Splitter: TransferFrom failed' - ); - } + require(amount > 0, 'ERC20Splitter: Amount to withdraw must be greater than zero'); + balances[tokenAddress][to] = 0; - withdrawnTokens[i] = tokenAddress; - withdrawnAmounts[i] = amount; + if (tokenAddress == address(0)) { + (bool success, ) = to.call{ value: amount }(''); + require(success, 'ERC20Splitter: Failed to send Ether'); + } else { + require( + IERC20(tokenAddress).transferFrom(address(this), to, amount), + 'ERC20Splitter: TransferFrom failed' + ); } - delete _hasToken[msg.sender][tokenAddress]; + withdrawnAmounts[i] = amount; + + delete hasToken[to][tokenAddress]; } - delete _userTokens[msg.sender]; + delete userTokens[to]; - emit Withdraw(msg.sender, withdrawnTokens, withdrawnAmounts); + emit Withdraw(to, withdrawnAmounts); } /** Internal Functions **/ @@ -105,7 +107,7 @@ contract ERC20Splitter is ReentrancyGuard { function _splitTokens( address tokenAddress, uint256 amount, - uint256[] calldata shares, + uint16[] calldata shares, address[] calldata recipients ) internal { require(shares.length == recipients.length, 'ERC20Splitter: Shares and recipients length mismatch'); @@ -138,9 +140,9 @@ contract ERC20Splitter is ReentrancyGuard { /// @param recipient The recipient of the token. /// @param tokenAddress The address of the token. function _addTokenForUser(address recipient, address tokenAddress) internal { - if (!_hasToken[recipient][tokenAddress]) { - _userTokens[recipient].push(tokenAddress); - _hasToken[recipient][tokenAddress] = true; + if (!hasToken[recipient][tokenAddress]) { + userTokens[recipient].push(tokenAddress); + hasToken[recipient][tokenAddress] = true; } } } diff --git a/contracts/mocks/MaliciousERC20.sol b/contracts/mocks/MaliciousERC20.sol new file mode 100644 index 0000000..dc655c5 --- /dev/null +++ b/contracts/mocks/MaliciousERC20.sol @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: CC0-1.0 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MaliciousERC20 is ERC20 { + constructor() ERC20("MaliciousToken", "MTK") {} + + function transferFrom( + address sender, + address recipient, + uint256 amount + ) public override returns (bool) { + return false; + } + + // Mint function for testing purposes + function mint(address to, uint256 amount) external { + _mint(to, amount); + } +} diff --git a/contracts/mocks/MaliciousRecipient.sol b/contracts/mocks/MaliciousRecipient.sol new file mode 100644 index 0000000..2bd7ac7 --- /dev/null +++ b/contracts/mocks/MaliciousRecipient.sol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: CC0-1.0 + +// contracts/MaliciousRecipient.sol +pragma solidity ^0.8.0; + +contract MaliciousRecipient { + // Fallback function that reverts when receiving Ether + fallback() external payable { + revert("MaliciousRecipient: Reverting on receive"); + } + + receive() external payable { + revert("MaliciousRecipient: Reverting on receive"); + } +} diff --git a/test/SplitterContract.test.ts b/test/SplitterContract.test.ts index 3686e6f..af97e36 100644 --- a/test/SplitterContract.test.ts +++ b/test/SplitterContract.test.ts @@ -2,42 +2,54 @@ import { ethers, network } from 'hardhat' import { loadFixture } from '@nomicfoundation/hardhat-network-helpers' import { expect } from 'chai' -import { MockERC20, ERC20Splitter } from '../typechain-types' +import { MockERC20, ERC20Splitter, MaliciousRecipient, MaliciousERC20 } from '../typechain-types' import { AddressZero } from '../utils/constants' describe('ERC20Splitter', () => { let splitter: ERC20Splitter let mockERC20: MockERC20 + let maliciousERC20: MaliciousERC20 let owner: Awaited> let recipient1: Awaited> let recipient2: Awaited> let recipient3: Awaited> + let anotherUser: Awaited> + let maliciousRecipient: MaliciousRecipient const tokenAmount = ethers.parseEther('100') const ethAmount = ethers.parseEther('1') before(async function () { // prettier-ignore - [owner, recipient1, recipient2, recipient3] = await ethers.getSigners() + [owner, recipient1, recipient2, recipient3, anotherUser] = await ethers.getSigners() }) async function deploySplitterContracts() { const MockERC20 = await ethers.getContractFactory('MockERC20') const ERC20Splitter = await ethers.getContractFactory('ERC20Splitter') + const MaliciousRecipientFactory = await ethers.getContractFactory('MaliciousRecipient') + maliciousRecipient = await MaliciousRecipientFactory.deploy() + await maliciousRecipient.waitForDeployment() + + const MaliciousERC20Factory = await ethers.getContractFactory('MaliciousERC20') + maliciousERC20 = await MaliciousERC20Factory.deploy() + await maliciousERC20.waitForDeployment() + const mockERC20 = await MockERC20.deploy() await mockERC20.waitForDeployment() const splitter = await ERC20Splitter.deploy() await splitter.waitForDeployment() - return { mockERC20, splitter } + return { mockERC20, splitter, maliciousERC20 } } beforeEach(async () => { const contracts = await loadFixture(deploySplitterContracts) mockERC20 = contracts.mockERC20 splitter = contracts.splitter + maliciousERC20 = contracts.maliciousERC20 // Mint tokens to the owner await mockERC20.connect(owner).mint(owner, ethers.parseEther('1000')) @@ -61,6 +73,9 @@ describe('ERC20Splitter', () => { method: 'hardhat_stopImpersonatingAccount', params: [splitterAddress], }) + + const tokenAmount = ethers.parseEther('100') + await maliciousERC20.mint(splitter, tokenAmount) }) describe('Main Functions', async () => { @@ -115,6 +130,91 @@ describe('ERC20Splitter', () => { ).to.be.revertedWith('ERC20Splitter: Shares and recipients length mismatch') }) + it('Should revert when msg.value does not match the expected Ether amount', async () => { + const incorrectMsgValue = ethers.parseEther('1') // Incorrect Ether amount + const correctEtherAmount = ethers.parseEther('2') // Correct Ether amount to be split + const tokenAddresses = [ethers.ZeroAddress] // Using address(0) for Ether + const amounts = [correctEtherAmount] // Amount to split among recipients + const shares = [[5000, 3000, 2000]] // Shares summing up to 100% + const recipients = [[recipient1.address, recipient2.address, recipient3.address]] + + await expect( + splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients, { + value: incorrectMsgValue, // Sending incorrect msg.value + }), + ).to.be.revertedWith('ERC20Splitter: Incorrect native token amount sent') + }) + it('Should revert when amount is 0', async () => { + const incorrectMsgValue = ethers.parseEther('1') // Incorrect Ether amount + const correctEtherAmount = ethers.parseEther('2') // Correct Ether amount to be split + const tokenAddresses = [ethers.ZeroAddress] // Using address(0) for Ether + const amounts = [correctEtherAmount] // Amount to split among recipients + const shares = [[5000, 3000, 2000]] // Shares summing up to 100% + const recipients = [[recipient1.address, recipient2.address, recipient3.address]] + + await expect( + splitter.connect(owner).deposit(tokenAddresses, [0], shares, recipients, { + value: incorrectMsgValue, // Sending incorrect msg.value + }), + ).to.be.revertedWith('ERC20Splitter: Amount must be greater than zero') + }) + + it('Should revert when tokenAddresses and amounts lengths mismatch', async () => { + const tokenAddresses = [mockERC20.getAddress(), ethers.ZeroAddress] + const amounts = [ethers.parseEther('100')] // Length 1, intentional mismatch + const shares = [[5000, 3000, 2000]] // Correct length + const recipients = [[recipient1.address, recipient2.address, recipient3.address]] + + await expect( + splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients, { + value: ethers.parseEther('0'), // No Ether sent + }), + ).to.be.revertedWith('ERC20Splitter: Invalid input lengths') + }) + + it('Should revert when tokenAddresses, shares, and recipients lengths mismatch', async () => { + const tokenAddresses = [mockERC20.getAddress(), ethers.ZeroAddress] + const amounts = [ethers.parseEther('100'), ethers.parseEther('2')] + const shares = [ + [5000, 3000, 2000], // Length 1 + ] // Length 1 (intentional mismatch) + const recipients = [ + [recipient1.address, recipient2.address, recipient3.address], + [recipient1.address, recipient2.address, recipient3.address], + ] // Length 2 + + await expect( + splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients, { + value: ethers.parseEther('2'), + }), + ).to.be.revertedWith('ERC20Splitter: Mismatched input sizes') + }) + + it('Should revert when shares and recipients lengths mismatch within sub-arrays', async () => { + const tokenAddresses = [mockERC20.getAddress()] // Length 1 + const amounts = [ethers.parseEther('100')] // Length 1 + const shares = [[5000, 3000, 2000]] // Length 1, sub-array length 3 + const recipients = [ + [recipient1.address, recipient2.address], // Length mismatch in sub-array + ] // Length 1, sub-array length 2 + + await expect(splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients)).to.be.revertedWith( + 'ERC20Splitter: Shares and recipients length mismatch', + ) + }) + + it('Should revert when ERC20 transferFrom fails during withdrawal', async () => { + const tokenAmount = ethers.parseEther('100') + const tokenAddresses = [await maliciousERC20.getAddress()] + const amounts = [tokenAmount] + const shares = [[10000]] // 100% share + const recipients = [[recipient1.address]] + + await expect(splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients)).to.be.revertedWith( + 'ERC20Splitter: Transfer failed', + ) + }) + it('Should handle multiple native token (ETH) deposits in a single transaction', async () => { const ethShares = [ [5000, 5000], @@ -196,7 +296,7 @@ describe('ERC20Splitter', () => { it('Should allow a recipient to withdraw their split ERC20 tokens without specifying token addresses', async () => { await expect(splitter.connect(recipient1).withdraw()) .to.emit(splitter, 'Withdraw') - .withArgs(recipient1.address, [await mockERC20.getAddress()], [ethers.parseEther('50')]) + .withArgs(recipient1.address, [ethers.parseEther('50')]) expect(await splitter.balances(await mockERC20.getAddress(), recipient1.address)).to.equal(0) }) @@ -213,13 +313,47 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient1.address, - [await mockERC20.getAddress(), AddressZero], // Expect both ERC-20 and native token [ethers.parseEther('50'), ethers.parseEther('0.5')], // 50 ERC20 tokens and 0.5 ETH ) expect(await splitter.balances(AddressZero, recipient1.address)).to.equal(0) expect(await splitter.balances(mockERC20.getAddress(), recipient1.address)).to.equal(0) }) + + it('Should handle withdraw() when user has no tokens', async () => { + await splitter.connect(anotherUser).withdraw() + }) + + it('Should revert when sending Ether to a recipient fails', async () => { + const malicious = await maliciousRecipient.getAddress() + + await network.provider.request({ + method: 'hardhat_impersonateAccount', + params: [malicious], + }) + + const ethAmount = ethers.parseEther('1') + const tokenAddresses = [ethers.ZeroAddress] // Ether represented by address zero + const amounts = [ethAmount] + const shares = [[10000]] // 100% share + const recipients = [[maliciousRecipient.getAddress()]] + + await splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients, { + value: ethAmount, + }) + + const maliciousSigner = await ethers.getSigner(malicious) + + await network.provider.send('hardhat_setBalance', [ + malicious, + ethers.toQuantity(ethers.parseEther('1')), // Setting 2 Ether + ]) + + // Attempt to withdraw as the malicious recipient + await expect(splitter.connect(maliciousSigner).withdraw()).to.be.revertedWith( + 'ERC20Splitter: Failed to send Ether', + ) + }) }) describe('Withdraw ERC-20 and Native Tokens', async () => { @@ -234,7 +368,7 @@ describe('ERC20Splitter', () => { it('Should allow a recipient to withdraw their split ERC20 tokens without specifying token addresses', async () => { await expect(splitter.connect(recipient1).withdraw()) .to.emit(splitter, 'Withdraw') - .withArgs(recipient1.address, [await mockERC20.getAddress()], [ethers.parseEther('50')]) + .withArgs(recipient1.address, [ethers.parseEther('50')]) expect(await splitter.balances(mockERC20.getAddress(), recipient1.address)).to.equal(0) }) @@ -250,8 +384,7 @@ describe('ERC20Splitter', () => { await expect(splitter.connect(recipient1).withdraw()) .to.emit(splitter, 'Withdraw') .withArgs( - recipient1.address, - [await mockERC20.getAddress(), AddressZero], // Expect both ERC-20 and native token + recipient1.address, // Expect both ERC-20 and native token [ethers.parseEther('50'), ethers.parseEther('0.5')], // 50 ERC20 tokens and 0.5 ETH ) @@ -275,7 +408,6 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient1.address, - [AddressZero], // Expect only native token (ETH) [ethers.parseEther('0.5')], // Expect 0.5 ETH (50% of 1 ETH) ) @@ -307,7 +439,6 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient1.address, - [AddressZero], // Only native token (ETH) [ethers.parseEther('1')], // Full 1 ETH ) @@ -317,7 +448,6 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient2.address, - [await mockERC20.getAddress()], // Only ERC-20 token [ethers.parseEther('50')], // 50% of ERC-20 tokens ) @@ -327,7 +457,6 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient3.address, - [await mockERC20.getAddress()], // Only ERC-20 token [ethers.parseEther('50')], // 50% of ERC-20 tokens ) @@ -368,7 +497,6 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient1.address, - [AddressZero], // Only native token (ETH) [ethers.parseEther('0.5')], // 50% of 1 ETH ) @@ -380,7 +508,6 @@ describe('ERC20Splitter', () => { .to.emit(splitter, 'Withdraw') .withArgs( recipient2.address, - [AddressZero, await mockERC20.getAddress()], // First ETH, then ERC-20 [ethers.parseEther('0.5'), ethers.parseEther('60')], // 50% of 1 ETH and 60 ERC-20 tokens ) @@ -391,7 +518,7 @@ describe('ERC20Splitter', () => { it('Should allow recipient3 to withdraw only ERC-20 tokens', async () => { await expect(splitter.connect(recipient3).withdraw()) .to.emit(splitter, 'Withdraw') - .withArgs(recipient3.address, [await mockERC20.getAddress()], [ethers.parseEther('40')]) + .withArgs(recipient3.address, [ethers.parseEther('40')]) expect(await splitter.balances(mockERC20.getAddress(), recipient3.address)).to.equal(0) })