diff --git a/contracts/ERC20Splitter.sol b/contracts/ERC20Splitter.sol index f021ad9..ae4f863 100644 --- a/contracts/ERC20Splitter.sol +++ b/contracts/ERC20Splitter.sol @@ -74,7 +74,6 @@ contract ERC20Splitter is ReentrancyGuard { address tokenAddress = senderTokens[i]; uint256 amount = balances[tokenAddress][to]; - require(amount > 0, 'ERC20Splitter: Amount to withdraw must be greater than zero'); balances[tokenAddress][to] = 0; if (tokenAddress == address(0)) { diff --git a/test/SplitterContract.test.ts b/test/SplitterContract.test.ts index af97e36..7b0a015 100644 --- a/test/SplitterContract.test.ts +++ b/test/SplitterContract.test.ts @@ -2,13 +2,12 @@ import { ethers, network } from 'hardhat' import { loadFixture } from '@nomicfoundation/hardhat-network-helpers' import { expect } from 'chai' -import { MockERC20, ERC20Splitter, MaliciousRecipient, MaliciousERC20 } from '../typechain-types' +import { MockERC20, ERC20Splitter, MaliciousRecipient } from '../typechain-types' import { AddressZero } from '../utils/constants' describe('ERC20Splitter', () => { let splitter: ERC20Splitter let mockERC20: MockERC20 - let maliciousERC20: MaliciousERC20 let owner: Awaited> let recipient1: Awaited> let recipient2: Awaited> @@ -32,24 +31,19 @@ describe('ERC20Splitter', () => { maliciousRecipient = await MaliciousRecipientFactory.deploy() await maliciousRecipient.waitForDeployment() - const MaliciousERC20Factory = await ethers.getContractFactory('MaliciousERC20') - maliciousERC20 = await MaliciousERC20Factory.deploy() - await maliciousERC20.waitForDeployment() - const mockERC20 = await MockERC20.deploy() await mockERC20.waitForDeployment() const splitter = await ERC20Splitter.deploy() await splitter.waitForDeployment() - return { mockERC20, splitter, maliciousERC20 } + return { mockERC20, splitter } } beforeEach(async () => { const contracts = await loadFixture(deploySplitterContracts) mockERC20 = contracts.mockERC20 splitter = contracts.splitter - maliciousERC20 = contracts.maliciousERC20 // Mint tokens to the owner await mockERC20.connect(owner).mint(owner, ethers.parseEther('1000')) @@ -75,7 +69,7 @@ describe('ERC20Splitter', () => { }) const tokenAmount = ethers.parseEther('100') - await maliciousERC20.mint(splitter, tokenAmount) + await mockERC20.mint(splitter, tokenAmount) }) describe('Main Functions', async () => { @@ -203,13 +197,14 @@ describe('ERC20Splitter', () => { ) }) - it('Should revert when ERC20 transferFrom fails during withdrawal', async () => { + it('Should revert when ERC20 transferFrom fails during deposit', async () => { const tokenAmount = ethers.parseEther('100') - const tokenAddresses = [await maliciousERC20.getAddress()] + const tokenAddresses = [await mockERC20.getAddress()] const amounts = [tokenAmount] const shares = [[10000]] // 100% share const recipients = [[recipient1.address]] + await mockERC20.transferReverts(true, 0) await expect(splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients)).to.be.revertedWith( 'ERC20Splitter: Transfer failed', ) @@ -354,6 +349,34 @@ describe('ERC20Splitter', () => { 'ERC20Splitter: Failed to send Ether', ) }) + it('Should revert when ERC20 transferFrom fails during withdraw', async () => { + const mockERC20false = await mockERC20.getAddress() + + await network.provider.request({ + method: 'hardhat_impersonateAccount', + params: [mockERC20false], + }) + + const ethAmount = ethers.parseEther('1') + const tokenAddresses = [ethers.ZeroAddress] // Ether represented by address zero + const amounts = [ethAmount] + const shares = [[10000]] // 100% share + const recipients = [[recipient1.getAddress()]] + + await splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients, { + value: ethAmount, + }) + + await network.provider.send('hardhat_setBalance', [ + mockERC20false, + ethers.toQuantity(ethers.parseEther('1')), // Setting 2 Ether + ]) + + await mockERC20.transferReverts(true, 0) + + // Attempt to withdraw as the malicious recipient + await expect(splitter.connect(recipient1).withdraw()).to.be.revertedWith('ERC20Splitter: TransferFrom failed') + }) }) describe('Withdraw ERC-20 and Native Tokens', async () => {