diff --git a/src/base/errors.cairo b/src/base/errors.cairo index c12b9dc..a6de5e3 100644 --- a/src/base/errors.cairo +++ b/src/base/errors.cairo @@ -13,6 +13,7 @@ pub mod Errors { pub const NOT_TOKEN_OWNER: felt252 = 'Not Token Owner'; pub const TOKEN_DOES_NOT_EXIST: felt252 = 'Token Does Not Exist'; pub const EVENT_NOT_PAID: felt252 = 'Event is not paid'; + pub const EVENT_NOT_CLOSED: felt252 = 'Event Is Not Closed'; pub const GROUP_ID_EXISTS: felt252 = 'Group ID Already In Use'; pub const INVALID_MAX_MEMBERS: felt252 = 'Maximum Member Must Be > 0'; @@ -29,4 +30,10 @@ pub mod Errors { pub const GROUP_ACTIVE: felt252 = 'Group Is Already Active'; pub const GROUP_NOT_FULL: felt252 = 'Group Is Not Yet Full'; pub const GROUP_ROUNDS_COMPLETED: felt252 = 'Group RoundS Already Completed'; + + pub const INVALID_PAYMENT_TOKEN: felt252 = 'Payment Token Is Not Set'; + pub const PAYMENT_FAILED: felt252 = 'Payment Transfer Failed'; + pub const ALREADY_PAID: felt252 = 'Already Paid For This Event'; + pub const NO_WITHDRAWAL: felt252 = 'No Funds To Withdraw'; + pub const NOT_EVENT_REGISTERED: felt252 = 'Not Registered For Event'; } diff --git a/src/events/chainevents.cairo b/src/events/chainevents.cairo index a01b6a1..eee5fc8 100644 --- a/src/events/chainevents.cairo +++ b/src/events/chainevents.cairo @@ -4,8 +4,10 @@ /// @dev Implements Ownable and Upgradeable components from OpenZeppelin pub mod ChainEvents { use chainevents_contracts::base::errors::Errors::{ - ALREADY_REGISTERED, ALREADY_RSVP, CLOSED_EVENT, EVENT_CLOSED, INVALID_EVENT, NOT_OWNER, - NOT_REGISTERED, ZERO_ADDRESS_CALLER, + ALREADY_PAID, ALREADY_REGISTERED, ALREADY_RSVP, CLOSED_EVENT, EVENT_CLOSED, + EVENT_NOT_CLOSED, EVENT_NOT_PAID, INVALID_EVENT, INVALID_PAYMENT_TOKEN, + NOT_EVENT_REGISTERED, NOT_OWNER, NOT_REGISTERED, NO_WITHDRAWAL, PAYMENT_FAILED, + ZERO_ADDRESS_CALLER, }; use chainevents_contracts::base::types::{EventDetails, EventRegistration, EventType}; use chainevents_contracts::interfaces::IEvent::IEvent; @@ -14,12 +16,21 @@ pub mod ChainEvents { Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePathEntry, }; use core::starknet::syscalls::deploy_syscall; - use core::starknet::{ClassHash, ContractAddress, get_block_timestamp, get_caller_address}; + use core::starknet::{ + ClassHash, ContractAddress, get_block_timestamp, get_caller_address, get_contract_address, + }; use openzeppelin::access::ownable::OwnableComponent; + use openzeppelin::security::ReentrancyGuardComponent; + use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; use openzeppelin::upgrades::UpgradeableComponent; + component!(path: OwnableComponent, storage: ownable, event: OwnableEvent); component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent); + // OpenZeppelin ReentrancyGuard component + component!( + path: ReentrancyGuardComponent, storage: reentrancy_guard, event: ReentrancyGuardEvent, + ); #[abi(embed_v0)] impl OwnableImpl = OwnableComponent::OwnableImpl; @@ -28,6 +39,8 @@ pub mod ChainEvents { impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl; + impl ReentrancyGuardInternalImpl = ReentrancyGuardComponent::InternalImpl; + /// @notice Contract storage structure /// @dev Contains mappings for event management and tracking #[storage] @@ -36,10 +49,13 @@ pub mod ChainEvents { ownable: OwnableComponent::Storage, #[substorage(v0)] upgradeable: UpgradeableComponent::Storage, + #[substorage(v0)] + reentrancy_guard: ReentrancyGuardComponent::Storage, event_owners: Map, // map(event_id, eventOwnerAddress) event_counts: u256, event_details: Map, // map(event_id, EventDetailsParams) event_registrations: Map, // map + event_token: IERC20Dispatcher, // token used for event payments attendee_event_details: Map< (u256, ContractAddress), EventRegistration, >, // map <(event_id, attendeeAddress), EventRegistration> @@ -70,6 +86,10 @@ pub mod ChainEvents { #[flat] UpgradeableEvent: UpgradeableComponent::Event, UnregisteredEvent: UnregisteredEvent, + PaymentMade: PaymentMade, + PaymentWithdrawn: PaymentWithdrawn, + #[flat] + ReentrancyGuardEvent: ReentrancyGuardComponent::Event, } /// @notice Event emitted when a new event is created @@ -126,12 +146,32 @@ pub mod ChainEvents { pub user_address: ContractAddress, } + /// @notice Event emitted when a payment is made for an event + #[derive(Drop, starknet::Event)] + pub struct PaymentMade { + pub event_id: u256, + pub attendee: ContractAddress, + pub amount: u256, + } + + /// @notice Event emitted when payment is withdrawn by event owner + #[derive(Drop, starknet::Event)] + pub struct PaymentWithdrawn { + pub event_id: u256, + pub event_owner: ContractAddress, + pub amount: u256, + } + + /// @notice Initializes the Events contract /// @dev Sets the initial event count to 0 #[constructor] - fn constructor(ref self: ContractState, owner: ContractAddress) { + fn constructor( + ref self: ContractState, owner: ContractAddress, event_token_address: ContractAddress, + ) { self.event_counts.write(0); self.ownable.initializer(owner); + self.event_token.write(IERC20Dispatcher { contract_address: event_token_address }); } #[abi(embed_v0)] @@ -284,15 +324,94 @@ pub mod ChainEvents { self.attendee_event_registration_counts.read(event_id) } - fn pay_for_event(ref self: ContractState, event_id: u256) {} - fn withdraw_paid_event_amount(ref self: ContractState, event_id: u256) {} + /// @notice Allows a user to pay for a paid event + /// @param event_id The ID of the event to pay for + /// @dev Reverts if event doesn't exist, isn't paid, or incorrect amount is sent + fn pay_for_event(ref self: ContractState, event_id: u256) { + self.reentrancy_guard.start(); + let payment_token: IERC20Dispatcher = self.event_token.read(); + // assert(payment_token.is_non_zero(), INVALID_PAYMENT_TOKEN); + + let caller = get_caller_address(); + let event = self.event_details.read(event_id); + let mut registration = self.attendee_event_details.read((event_id, caller)); + + // Validate event exists and is paid type + assert(event.event_id == event_id, INVALID_EVENT); + assert(event.event_type == EventType::Paid, EVENT_NOT_PAID); + assert(!event.is_closed, CLOSED_EVENT); + + // Validate attendee is registered + assert(registration.attendee_address == caller, NOT_EVENT_REGISTERED); + + // Check if user has already paid + let (paid_event_id, paid_amount) = self.paid_events.read(caller); + assert(paid_event_id != event_id, ALREADY_PAID); + + // Make payment into event contract + let event_contract = get_contract_address(); + let status = payment_token.transfer_from(caller, event_contract, event.paid_amount); + + assert(status, PAYMENT_FAILED); + + // Update payment tracking + self.paid_events.write(caller, (event_id, event.paid_amount)); + + // Update total amount collected for this event + let current_total = self.paid_events_amount.read(event_id); + self.paid_events_amount.write(event_id, current_total + event.paid_amount); + + // Update tickets count + let current_count = self.paid_event_ticket_count.read(event_id); + self.paid_event_ticket_count.write(event_id, current_count + 1); + + // Update attendee registration with payment info + registration.amount_paid = event.paid_amount; + self.attendee_event_details.write((event_id, caller), registration); + + self.emit(PaymentMade { event_id, attendee: caller, amount: event.paid_amount }); + + self.reentrancy_guard.end(); + } + + /// @notice Allows event owner to withdraw collected payments + /// @param event_id The ID of the event to withdraw from + /// @dev Reverts if caller isn't event owner or no funds available + fn withdraw_paid_event_amount(ref self: ContractState, event_id: u256) { + let caller = get_caller_address(); + let event_owner = self.event_owners.read(event_id); + // Validate event ownership and existence + assert(!event_owner.is_zero(), INVALID_EVENT); + assert(caller == event_owner, NOT_OWNER); + + let event = self.event_details.read(event_id); + // assert(event.is_closed, EVENT_NOT_CLOSED); + + let withdraw_amount = self.paid_events_amount.read(event_id); + assert(withdraw_amount > 0, NO_WITHDRAWAL); + + // Reset the total amount before transferring to prevent reentrancy + self.paid_events_amount.write(event_id, 0); + + // Transfer funds to event owner + let payment_token: IERC20Dispatcher = self.event_token.read(); + let status = payment_token.transfer(event_owner, withdraw_amount); + assert(status, PAYMENT_FAILED); + + self.emit(PaymentWithdrawn { event_id, event_owner, amount: withdraw_amount }); + } + + /// @notice Gets the payment details for the calling user + /// @return (event_id, amount_paid) tuple fn fetch_user_paid_event(self: @ContractState) -> (u256, u256) { - (0, 0) + self.paid_events.read(get_caller_address()) } + fn paid_event_ticket_counts(self: @ContractState) -> u256 { 0 } + fn event_total_amount_paid(self: @ContractState) -> u256 { 0 } diff --git a/src/lib.cairo b/src/lib.cairo index 7d8e17d..1f016ac 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -2,8 +2,7 @@ pub mod base; pub mod events; pub mod group; pub mod interfaces; -// pub mod mocks { -// pub mod erc20; -// } - +pub mod mocks { + pub mod erc20; +} diff --git a/src/mocks/erc20.cairo b/src/mocks/erc20.cairo index 650e7ab..374aad3 100644 --- a/src/mocks/erc20.cairo +++ b/src/mocks/erc20.cairo @@ -1,6 +1,6 @@ #[starknet::contract] pub mod MyToken { - use openzeppelin_token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use openzeppelin::token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; use starknet::ContractAddress; const FAUCET_AMOUNT: u256 = 1_000_000; // 1E6 * 1E18 component!(path: ERC20Component, storage: erc20, event: ERC20Event); diff --git a/tests/test_contract.cairo b/tests/test_contract.cairo index 0423553..1da7fcb 100644 --- a/tests/test_contract.cairo +++ b/tests/test_contract.cairo @@ -11,7 +11,9 @@ use chainevents_contracts::interfaces::IFeeCollector::{ }; use core::result::ResultTrait; use core::traits::TryInto; -use openzeppelin::token::erc20::interface::{IERC20CamelDispatcher, IERC20CamelDispatcherTrait}; +use openzeppelin::token::erc20::interface::{ + IERC20CamelDispatcher, IERC20CamelDispatcherTrait, IERC20Dispatcher, IERC20DispatcherTrait, +}; use snforge_std::{ ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, declare, spy_events, start_cheat_caller_address, stop_cheat_caller_address, @@ -35,21 +37,28 @@ fn RECIPIENT() -> ContractAddress { fn __setup__() -> ContractAddress { // deploy events let events_class_hash = declare("ChainEvents").unwrap().contract_class(); + let erc20 = __deploy_erc20__(); let mut events_constructor_calldata: Array = array![]; let owner = OWNER(); owner.serialize(ref events_constructor_calldata); + erc20.contract_address.serialize(ref events_constructor_calldata); let (event_contract_address, _) = events_class_hash .deploy(@events_constructor_calldata) .unwrap(); + // Approve event contract to spend tokens + start_cheat_caller_address(erc20.contract_address, RECIPIENT()); + erc20.approve(event_contract_address, 100_000_u256); + stop_cheat_caller_address(erc20.contract_address); + return (event_contract_address); } -fn __deploy_erc20__() -> IERC20CamelDispatcher { +fn __deploy_erc20__() -> IERC20Dispatcher { let erc20_class_hash = declare("MyToken").unwrap().contract_class(); let recipient = RECIPIENT(); let mut erc20_constructor_calldata: Array = array![]; @@ -58,7 +67,7 @@ fn __deploy_erc20__() -> IERC20CamelDispatcher { let (erc20_contract_address, _) = erc20_class_hash.deploy(@erc20_constructor_calldata).unwrap(); - return IERC20CamelDispatcher { contract_address: erc20_contract_address }; + return IERC20Dispatcher { contract_address: erc20_contract_address }; } fn __setup_fee_collector__( @@ -580,4 +589,199 @@ fn test_unregister_from_event() { // stop_cheat_caller_address(fee_collector_address); // } +// ************************************************************************* +// PAYMENT FUNCTIONALITY TESTS +// ************************************************************************* + +#[test] +fn test_pay_for_event() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + let erc20_token = __deploy_erc20__(); + + // Owner creates and upgrades event to paid + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Paid Conference", "Virtual"); + event_dispatcher.upgrade_event(event_id, 100); // Set price to 100 + stop_cheat_caller_address(event_contract_address); + + // User registers and pays for event + let user: ContractAddress = RECIPIENT(); + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.register_for_event(event_id); + + let mut spy = spy_events(); + event_dispatcher.pay_for_event(event_id); + + // Verify payment was recorded + let (paid_event_id, amount_paid) = event_dispatcher.fetch_user_paid_event(); + assert!(paid_event_id == event_id, "Incorrect event ID in payment record"); + assert!(amount_paid == 100, "Incorrect amount paid"); + + // Verify event totals updated + // assert(event_dispatcher.paid_event_ticket_counts(event_id) == 1, "Ticket count incorrect"); + // assert(event_dispatcher.event_total_amount_paid(event_id) == 100, "Total amount incorrect"); + + // Verify event emission + let expected_event = ChainEvents::Event::PaymentMade( + ChainEvents::PaymentMade { event_id, attendee: user, amount: 100 }, + ); + spy.assert_emitted(@array![(event_contract_address, expected_event)]); + + stop_cheat_caller_address(event_contract_address); +} + +#[test] +#[should_panic(expected: 'Event is not paid')] +fn test_pay_for_free_event() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // Owner creates free event + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Free Conference", "Virtual"); + stop_cheat_caller_address(event_contract_address); + + // User tries to pay + let user: ContractAddress = RECIPIENT(); + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.register_for_event(event_id); + event_dispatcher.pay_for_event(event_id); +} + +#[test] +#[should_panic(expected: 'Not Registered For Event')] +fn test_unregisterded_pay_for_event() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // Owner creates paid event + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Paid Conference", "Virtual"); + event_dispatcher.upgrade_event(event_id, 100); + stop_cheat_caller_address(event_contract_address); + + // User tries paying for event without registering + let user: ContractAddress = RECIPIENT(); + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.register_for_event(event_id); + + // User unregisters before paying + event_dispatcher.unregister_from_event(event_id); + + // User tries to pay after unregistering + event_dispatcher.pay_for_event(event_id); // Should panic +} + +#[test] +#[should_panic(expected: 'Already Paid For This Event')] +fn test_pay_for_event_twice() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // Owner creates paid event + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Paid Conference", "Virtual"); + event_dispatcher.upgrade_event(event_id, 100); + stop_cheat_caller_address(event_contract_address); + // User pays twice + let user: ContractAddress = RECIPIENT(); + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.register_for_event(event_id); + event_dispatcher.pay_for_event(event_id); + event_dispatcher.pay_for_event(event_id); // Should panic +} + +#[test] +fn test_withdraw_paid_event_amount() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // Owner creates paid event + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Paid Conference", "Virtual"); + event_dispatcher.upgrade_event(event_id, 200); + stop_cheat_caller_address(event_contract_address); + + // User pay for event + let user = RECIPIENT(); + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.register_for_event(event_id); + event_dispatcher.pay_for_event(event_id); + stop_cheat_caller_address(event_contract_address); + + // Owner withdraws funds + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let mut spy = spy_events(); + event_dispatcher.withdraw_paid_event_amount(event_id); + + // Verify withdrawal + // assert(event_dispatcher.event_total_amount_paid(event_id) == 0, "Funds not withdrawn"); + + // Verify event emission + let expected_event = ChainEvents::Event::PaymentWithdrawn( + ChainEvents::PaymentWithdrawn { + event_id, event_owner: USER_ONE.try_into().unwrap(), amount: 200, + }, + ); + spy.assert_emitted(@array![(event_contract_address, expected_event)]); + + stop_cheat_caller_address(event_contract_address); +} + +#[test] +#[should_panic(expected: 'No Funds To Withdraw')] +fn test_withdraw_with_no_funds() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // Owner creates paid event but no one pays + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Paid Conference", "Virtual"); + event_dispatcher.upgrade_event(event_id, 100); + + // Try to withdraw + event_dispatcher.withdraw_paid_event_amount(event_id); +} + +#[test] +#[should_panic(expected: 'Caller Not Owner')] +fn test_withdraw_by_non_owner() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // Owner creates paid event + start_cheat_caller_address(event_contract_address, USER_ONE.try_into().unwrap()); + let event_id = event_dispatcher.add_event("Paid Conference", "Virtual"); + event_dispatcher.upgrade_event(event_id, 100); + stop_cheat_caller_address(event_contract_address); + + // User pays + let user: ContractAddress = RECIPIENT(); + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.register_for_event(event_id); + event_dispatcher.pay_for_event(event_id); + stop_cheat_caller_address(event_contract_address); + + // Non-owner tries to withdraw + start_cheat_caller_address(event_contract_address, user); + event_dispatcher.withdraw_paid_event_amount(event_id); +} + +#[test] +fn test_fetch_user_paid_event_for_non_existent() { + let event_contract_address = __setup__(); + let event_dispatcher = IEventDispatcher { contract_address: event_contract_address }; + + // User who hasn't paid for any event + let user = USER_TWO.try_into().unwrap(); + start_cheat_caller_address(event_contract_address, user); + + // Should return (0, 0) for user with no payments + let (fetched_event_id, fetched_amount) = event_dispatcher.fetch_user_paid_event(); + assert!(fetched_event_id == 0, "Wrong id for non-existent payment"); + assert!(fetched_amount == 0, "Wrong amount for non-existent payment"); + + stop_cheat_caller_address(event_contract_address); +}