diff --git a/contracts/ERC20Splitter.sol b/contracts/ERC20Splitter.sol index 50e1936..eeabdaf 100644 --- a/contracts/ERC20Splitter.sol +++ b/contracts/ERC20Splitter.sol @@ -9,7 +9,7 @@ contract ERC20Splitter is ReentrancyGuard { mapping(address => mapping(address => uint256)) public balances; // userAddress => tokenAddress[] mapping(address => address[]) private userTokens; - // tokenAddress => boolean + // tokenAddress => userAddress => boolean mapping(address => mapping(address => bool)) private hasToken; /** Events **/ @@ -23,6 +23,14 @@ contract ERC20Splitter is ReentrancyGuard { ); event Withdraw(address indexed user, address[] tokenAddresses, uint256[] amounts); + event RecipientSplit( + address indexed depositor, + address indexed tokenAddress, + address indexed recipient, + uint256 amount, + uint16 sharePercentage + ); + uint16 public constant MAX_SHARES = 10000; /** External Functions **/ @@ -61,8 +69,8 @@ contract ERC20Splitter is ReentrancyGuard { /// @notice Withdraw all tokens that the caller is entitled to. /// Tokens are automatically determined based on previous deposits. function withdraw() external nonReentrant { - address payable to = payable(msg.sender); - address[] storage senderTokens = userTokens[msg.sender]; + address recipient = msg.sender; + address[] storage senderTokens = userTokens[recipient]; if (senderTokens.length == 0) { return; @@ -72,26 +80,26 @@ contract ERC20Splitter is ReentrancyGuard { for (uint256 i = 0; i < senderTokens.length; i++) { address tokenAddress = senderTokens[i]; - uint256 amount = balances[tokenAddress][to]; + uint256 amount = balances[tokenAddress][recipient]; - balances[tokenAddress][to] = 0; + balances[tokenAddress][recipient] = 0; if (tokenAddress == address(0)) { - to.transfer(amount); + payable(msg.sender).transfer(amount); } else { require( - IERC20(tokenAddress).transferFrom(address(this), to, amount), + IERC20(tokenAddress).transferFrom(address(this), recipient, amount), 'ERC20Splitter: TransferFrom failed' ); } withdrawnAmounts[i] = amount; - delete hasToken[to][tokenAddress]; + delete hasToken[recipient][tokenAddress]; } - emit Withdraw(to, userTokens[msg.sender], withdrawnAmounts); + emit Withdraw(recipient, userTokens[recipient], withdrawnAmounts); - delete userTokens[to]; + delete userTokens[recipient]; } /** Internal Functions **/ @@ -130,6 +138,8 @@ contract ERC20Splitter is ReentrancyGuard { balances[tokenAddress][recipients[i]] += recipientAmount; _addTokenForUser(recipients[i], tokenAddress); + + emit RecipientSplit(msg.sender, tokenAddress, recipients[i], recipientAmount, shares[i]); } } diff --git a/test/SplitterContract.test.ts b/test/SplitterContract.test.ts index 7c763a4..ad40501 100644 --- a/test/SplitterContract.test.ts +++ b/test/SplitterContract.test.ts @@ -264,6 +264,46 @@ describe('ERC20Splitter', () => { // Check balances for recipient3 (40% of 100 ERC-20 tokens = 40 tokens) expect(await splitter.balances(mockERC20.getAddress(), recipient3.address)).to.equal(ethers.parseEther('40')) }) + it('Should emit RecipientSplit events for each recipient on deposit', async function () { + const mockAddress = await mockERC20.getAddress() + const tokenAmount = ethers.parseEther('100') + const ethAmount = ethers.parseEther('1') + + const tokenAddresses = [await mockERC20.getAddress(), AddressZero] + const amounts = [tokenAmount, ethAmount] + const shares = [ + [5000, 3000, 2000], // For ERC20 token + [7000, 2000, 1000], // For ETH + ] + const recipients = [ + [recipient1.address, recipient2.address, recipient3.address], + [recipient1.address, recipient2.address, recipient3.address], + ] + + await expect(splitter.connect(owner).deposit(tokenAddresses, amounts, shares, recipients, { value: ethAmount })) + .to.emit(splitter, 'Deposit') + .withArgs(owner.address, tokenAddresses, amounts, shares, recipients) + .and.to.emit(splitter, 'RecipientSplit') + .withArgs(owner.address, mockAddress, recipient1.address, ethers.parseEther('50'), 5000) + .and.to.emit(splitter, 'RecipientSplit') + .withArgs(owner.address, mockAddress, recipient2.address, ethers.parseEther('30'), 3000) + .and.to.emit(splitter, 'RecipientSplit') + .withArgs(owner.address, mockAddress, recipient3.address, ethers.parseEther('20'), 2000) + .and.to.emit(splitter, 'RecipientSplit') + .withArgs(owner.address, AddressZero, recipient1.address, ethers.parseEther('0.7'), 7000) + .and.to.emit(splitter, 'RecipientSplit') + .withArgs(owner.address, AddressZero, recipient2.address, ethers.parseEther('0.2'), 2000) + .and.to.emit(splitter, 'RecipientSplit') + .withArgs(owner.address, AddressZero, recipient3.address, ethers.parseEther('0.1'), 1000) + + expect(await splitter.balances(mockAddress, recipient1.address)).to.equal(ethers.parseEther('50')) + expect(await splitter.balances(mockAddress, recipient2.address)).to.equal(ethers.parseEther('30')) + expect(await splitter.balances(mockAddress, recipient3.address)).to.equal(ethers.parseEther('20')) + + expect(await splitter.balances(AddressZero, recipient1.address)).to.equal(ethers.parseEther('0.7')) + expect(await splitter.balances(AddressZero, recipient2.address)).to.equal(ethers.parseEther('0.2')) + expect(await splitter.balances(AddressZero, recipient3.address)).to.equal(ethers.parseEther('0.1')) + }) }) describe('Withdraw', async () => { @@ -332,7 +372,6 @@ describe('ERC20Splitter', () => { await mockERC20.transferReverts(true, 0) - // Attempt to withdraw as the malicious recipient await expect(splitter.connect(recipient1).withdraw()).to.be.revertedWith('ERC20Splitter: TransferFrom failed') }) })