Skip to content

Commit

Permalink
Merge pull request #26 from euler-xyz/feature-request
Browse files Browse the repository at this point in the history
Change `enableReward` and `disableReward` interface, add `isRewardEnabled`
  • Loading branch information
kasperpawlowski authored Aug 6, 2024
2 parents 0139e23 + dca04a2 commit 014e65e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 35 deletions.
59 changes: 50 additions & 9 deletions src/BaseRewardStreams.sol
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,12 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
/// @param rewarded The address of the rewarded token.
/// @param reward The address of the reward token.
/// @param recipient The address to receive the spillover reward tokens.
function updateReward(address rewarded, address reward, address recipient) public virtual override {
/// @return The amount of the spillover reward tokens claimed.
function updateReward(
address rewarded,
address reward,
address recipient
) public virtual override returns (uint256) {
address msgSender = _msgSender();

// If the account disables the rewards we pass an account balance of zero to not accrue any.
Expand All @@ -221,8 +226,10 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
);

if (recipient != address(0)) {
claim(address(0), rewarded, reward, recipient);
return claim(address(0), rewarded, reward, recipient);
}

return 0;
}

/// @notice Claims earned reward.
Expand All @@ -231,12 +238,13 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
/// @param reward The address of the reward token.
/// @param recipient The address to receive the claimed reward tokens.
/// @param forfeitRecentReward Whether to forfeit the recent rewards and not update the accumulator.
/// @return The amount of the claimed reward tokens.
function claimReward(
address rewarded,
address reward,
address recipient,
bool forfeitRecentReward
) external virtual override nonReentrant {
) external virtual override nonReentrant returns (uint256) {
address msgSender = _msgSender();

// If the account disables the rewards we pass an account balance of zero to not accrue any.
Expand All @@ -252,19 +260,21 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
forfeitRecentReward
);

claim(msgSender, rewarded, reward, recipient);
return claim(msgSender, rewarded, reward, recipient);
}

/// @notice Enable reward token.
/// @dev There can be at most MAX_REWARDS_ENABLED rewards enabled for the reward token and the account.
/// @param rewarded The address of the rewarded token.
/// @param reward The address of the reward token.
function enableReward(address rewarded, address reward) external virtual override {
/// @return Whether the reward token was enabled.
function enableReward(address rewarded, address reward) external virtual override returns (bool) {
address msgSender = _msgSender();
AccountStorage storage accountStorage = accounts[msgSender][rewarded];
SetStorage storage accountEnabledRewards = accountStorage.enabledRewards;
bool wasEnabled = accountEnabledRewards.insert(reward);

if (accountEnabledRewards.insert(reward)) {
if (wasEnabled) {
if (accountEnabledRewards.numElements > MAX_REWARDS_ENABLED) {
revert TooManyRewardsEnabled();
}
Expand All @@ -280,17 +290,25 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard

emit RewardEnabled(msgSender, rewarded, reward);
}

return wasEnabled;
}

/// @notice Disable reward token.
/// @param rewarded The address of the rewarded token.
/// @param reward The address of the reward token.
/// @param forfeitRecentReward Whether to forfeit the recent rewards and not update the accumulator.
function disableReward(address rewarded, address reward, bool forfeitRecentReward) external virtual override {
/// @return Whether the reward token was disabled.
function disableReward(
address rewarded,
address reward,
bool forfeitRecentReward
) external virtual override returns (bool) {
address msgSender = _msgSender();
AccountStorage storage accountStorage = accounts[msgSender][rewarded];
bool wasDisabled = accountStorage.enabledRewards.remove(reward);

if (accountStorage.enabledRewards.remove(reward)) {
if (wasDisabled) {
DistributionStorage storage distributionStorage = distributions[rewarded][reward];
uint256 currentAccountBalance = accountStorage.balance;

Expand All @@ -307,6 +325,8 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard

emit RewardDisabled(msgSender, rewarded, reward);
}

return wasDisabled;
}

/// @notice Returns the earned reward token amount for a specific account and rewarded token.
Expand Down Expand Up @@ -348,6 +368,19 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
return accounts[account][rewarded].enabledRewards.get();
}

/// @notice Checks if a specific reward token is enabled for an account and rewarded token.
/// @param account The address of the account.
/// @param rewarded The address of the rewarded token.
/// @param reward The address of the reward token to check if enabled.
/// @return Whether the reward token is enabled for the account and rewarded token.
function isRewardEnabled(
address account,
address rewarded,
address reward
) external view virtual override returns (bool) {
return accounts[account][rewarded].enabledRewards.contains(reward);
}

/// @notice Returns the rewarded token balance of a specific account.
/// @param account The address of the account.
/// @param rewarded The address of the rewarded token.
Expand Down Expand Up @@ -502,7 +535,13 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
/// @param rewarded The address of the rewarded token.
/// @param reward The address of the reward token.
/// @param recipient The address to which the claimed reward will be transferred.
function claim(address account, address rewarded, address reward, address recipient) internal virtual {
/// @return The amount of the claimed reward tokens.
function claim(
address account,
address rewarded,
address reward,
address recipient
) internal virtual returns (uint256) {
EarnStorage storage accountEarned = accounts[account][rewarded].earned[reward];
uint128 amount = accountEarned.claimable;

Expand All @@ -521,6 +560,8 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard
pushToken(IERC20(reward), recipient, amount);
emit RewardClaimed(account, rewarded, reward, amount);
}

return amount;
}

/// @notice Updates the data for a specific account, rewarded token and reward token.
Expand Down
9 changes: 5 additions & 4 deletions src/interfaces/IRewardStreams.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ interface IRewardStreams {
function MAX_DISTRIBUTION_LENGTH() external view returns (uint256);
function MAX_REWARDS_ENABLED() external view returns (uint256);
function registerReward(address rewarded, address reward, uint48 startEpoch, uint128[] calldata rewardAmounts) external;
function updateReward(address rewarded, address reward, address recipient) external;
function claimReward(address rewarded, address reward, address recipient, bool forfeitRecentReward) external;
function enableReward(address rewarded, address reward) external;
function disableReward(address rewarded, address reward, bool forfeitRecentReward) external;
function updateReward(address rewarded, address reward, address recipient) external returns (uint256);
function claimReward(address rewarded, address reward, address recipient, bool forfeitRecentReward) external returns (uint256);
function enableReward(address rewarded, address reward) external returns (bool);
function disableReward(address rewarded, address reward, bool forfeitRecentReward) external returns (bool);
function earnedReward(address account, address rewarded, address reward, bool forfeitRecentReward) external view returns (uint256);
function enabledRewards(address account, address rewarded) external view returns (address[] memory);
function isRewardEnabled(address account, address rewarded, address reward) external view returns (bool);
function balanceOf(address account, address rewarded) external view returns (uint256);
function rewardAmount(address rewarded, address reward) external view returns (uint256);
function totalRewardedEligible(address rewarded, address reward) external view returns (uint256);
Expand Down
4 changes: 0 additions & 4 deletions test/harness/BaseRewardStreamsHarness.sol
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ contract BaseRewardStreamsHarness is BaseRewardStreams {
accounts[account][rewarded].enabledRewards.insert(reward);
}

function isRewardEnabled(address account, address rewarded, address reward) external view returns (bool) {
return accounts[account][rewarded].enabledRewards.contains(reward);
}

function getAccountEarnedData(
address account,
address rewarded,
Expand Down
4 changes: 0 additions & 4 deletions test/harness/StakingRewardStreamsHarness.sol
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ contract StakingRewardStreamsHarness is StakingRewardStreams {
accounts[account][rewarded].enabledRewards.insert(reward);
}

function isRewardEnabled(address account, address rewarded, address reward) external view returns (bool) {
return accounts[account][rewarded].enabledRewards.contains(reward);
}

function getAccountEarnedData(
address account,
address rewarded,
Expand Down
4 changes: 0 additions & 4 deletions test/harness/TrackingRewardStreamsHarness.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ contract TrackingRewardStreamsHarness is TrackingRewardStreams {
accounts[account][rewarded].enabledRewards.insert(reward);
}

function isRewardEnabled(address account, address rewarded, address reward) external view returns (bool) {
return accounts[account][rewarded].enabledRewards.contains(reward);
}

function getAccountEarnedData(
address account,
address rewarded,
Expand Down
30 changes: 20 additions & 10 deletions test/unit/Scenarios.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ contract ScenarioTest is Test {

// claim the rewards earned by address(0)
uint256 preClaimBalance = MockERC20(reward).balanceOf(address(this));
stakingDistributor.updateReward(stakingRewarded, reward, address(this));
uint256 claimedAmount = stakingDistributor.updateReward(stakingRewarded, reward, address(this));
assertEq(MockERC20(reward).balanceOf(address(this)), preClaimBalance + expectedAmount);
assertEq(claimedAmount, expectedAmount);

preClaimBalance = MockERC20(reward).balanceOf(address(this));
trackingDistributor.updateReward(trackingRewarded, reward, address(this));
claimedAmount = trackingDistributor.updateReward(trackingRewarded, reward, address(this));
assertEq(MockERC20(reward).balanceOf(address(this)), preClaimBalance + expectedAmount);
assertEq(claimedAmount, expectedAmount);

// verify total claimed
assertEq(stakingDistributor.totalRewardClaimed(stakingRewarded, reward), expectedAmount);
Expand All @@ -127,22 +129,26 @@ contract ScenarioTest is Test {
);

// only update the rewards
stakingDistributor.updateReward(stakingRewarded, reward, address(0));
claimedAmount = stakingDistributor.updateReward(stakingRewarded, reward, address(0));
preClaimBalance = MockERC20(reward).balanceOf(address(this));
assertEq(MockERC20(reward).balanceOf(address(this)), preClaimBalance);
assertEq(claimedAmount, 0);

trackingDistributor.updateReward(trackingRewarded, reward, address(0));
claimedAmount = trackingDistributor.updateReward(trackingRewarded, reward, address(0));
preClaimBalance = MockERC20(reward).balanceOf(address(this));
assertEq(MockERC20(reward).balanceOf(address(this)), preClaimBalance);
assertEq(claimedAmount, 0);

// claim the rewards earned by address(0)
preClaimBalance = MockERC20(reward).balanceOf(address(this));
stakingDistributor.updateReward(stakingRewarded, reward, address(this));
claimedAmount = stakingDistributor.updateReward(stakingRewarded, reward, address(this));
assertApproxEqAbs(MockERC20(reward).balanceOf(address(this)), preClaimBalance + expectedAmount, 1);
assertApproxEqAbs(claimedAmount, expectedAmount, 1);

preClaimBalance = MockERC20(reward).balanceOf(address(this));
trackingDistributor.updateReward(trackingRewarded, reward, address(this));
claimedAmount = trackingDistributor.updateReward(trackingRewarded, reward, address(this));
assertApproxEqAbs(MockERC20(reward).balanceOf(address(this)), preClaimBalance + expectedAmount, 1);
assertApproxEqAbs(claimedAmount, expectedAmount, 1);

// verify total claimed
assertApproxEqAbs(stakingDistributor.totalRewardClaimed(stakingRewarded, reward), totalAmount, 1);
Expand Down Expand Up @@ -233,15 +239,17 @@ contract ScenarioTest is Test {
stakingDistributor.updateReward(stakingRewarded, reward, address(0));
assertEq(MockERC20(reward).balanceOf(PARTICIPANT_1), preClaimBalance);

stakingDistributor.claimReward(stakingRewarded, reward, PARTICIPANT_1, false);
uint256 claimedAmount = stakingDistributor.claimReward(stakingRewarded, reward, PARTICIPANT_1, false);
assertApproxEqRel(MockERC20(reward).balanceOf(PARTICIPANT_1), preClaimBalance + expectedAmount, ALLOWED_DELTA);
assertApproxEqRel(claimedAmount, expectedAmount, ALLOWED_DELTA);

preClaimBalance = MockERC20(reward).balanceOf(PARTICIPANT_1);
trackingDistributor.updateReward(trackingRewarded, reward, address(0));
assertEq(MockERC20(reward).balanceOf(PARTICIPANT_1), preClaimBalance);

trackingDistributor.claimReward(trackingRewarded, reward, PARTICIPANT_1, false);
claimedAmount = trackingDistributor.claimReward(trackingRewarded, reward, PARTICIPANT_1, false);
assertApproxEqRel(MockERC20(reward).balanceOf(PARTICIPANT_1), preClaimBalance + expectedAmount, ALLOWED_DELTA);
assertApproxEqRel(claimedAmount, expectedAmount, ALLOWED_DELTA);

// verify total claimed
assertApproxEqRel(stakingDistributor.totalRewardClaimed(stakingRewarded, reward), expectedAmount, ALLOWED_DELTA);
Expand Down Expand Up @@ -271,12 +279,14 @@ contract ScenarioTest is Test {

// claim the rewards earned by the participant (will be transferred to this contract)
preClaimBalance = MockERC20(reward).balanceOf(address(this));
stakingDistributor.claimReward(stakingRewarded, reward, address(this), false);
claimedAmount = stakingDistributor.claimReward(stakingRewarded, reward, address(this), false);
assertApproxEqRel(MockERC20(reward).balanceOf(address(this)), preClaimBalance + expectedAmount, ALLOWED_DELTA);
assertApproxEqRel(claimedAmount, expectedAmount, ALLOWED_DELTA);

preClaimBalance = MockERC20(reward).balanceOf(address(this));
trackingDistributor.claimReward(trackingRewarded, reward, address(this), false);
claimedAmount = trackingDistributor.claimReward(trackingRewarded, reward, address(this), false);
assertApproxEqRel(MockERC20(reward).balanceOf(address(this)), preClaimBalance + expectedAmount, ALLOWED_DELTA);
assertApproxEqRel(claimedAmount, expectedAmount, ALLOWED_DELTA);

// verify total claimed
assertApproxEqRel(stakingDistributor.totalRewardClaimed(stakingRewarded, reward), totalAmount, ALLOWED_DELTA);
Expand Down
26 changes: 26 additions & 0 deletions test/unit/View.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,32 @@ contract ViewTest is Test {
}
}

function test_IsRewardEnabled(address account, address rewarded, uint8 n, uint256 index, uint256 seed) external {
account = boundAddr(account);
rewarded = boundAddr(rewarded);
n = uint8(bound(n, 1, 5));
index = uint8(bound(index, 0, n - 1));

vm.startPrank(account);
for (uint8 i = 0; i < n; i++) {
address reward = address(uint160(uint256(keccak256(abi.encode(seed, i)))));
bool wasEnabled = distributor.enableReward(rewarded, reward);

assertTrue(distributor.isRewardEnabled(account, rewarded, reward));
assertTrue(wasEnabled);
assertFalse(distributor.enableReward(rewarded, reward));
}

address[] memory enabledRewards = distributor.enabledRewards(account, rewarded);
assertEq(enabledRewards.length, n);

bool wasDisabled = distributor.disableReward(rewarded, enabledRewards[index], false);

assertFalse(distributor.isRewardEnabled(account, rewarded, enabledRewards[index]));
assertTrue(wasDisabled);
assertFalse(distributor.disableReward(rewarded, enabledRewards[index], false));
}

function test_BalanceOf(address account, address rewarded, uint256 balance) external {
distributor.setAccountBalance(account, rewarded, balance);
assertEq(distributor.balanceOf(account, rewarded), balance);
Expand Down

0 comments on commit 014e65e

Please sign in to comment.