diff --git a/wallet/src/address.rs b/wallet/src/address.rs index dbf6587..9b3c51c 100644 --- a/wallet/src/address.rs +++ b/wallet/src/address.rs @@ -74,6 +74,10 @@ impl AddressEntry { } pub fn is_custodian(&self) -> bool { + self.role == Role::Custodian + } + + pub fn is_controller_or_custodian(&self) -> bool { self.role == Role::Controller || self.role == Role::Custodian } } @@ -136,6 +140,12 @@ impl AddressBook { self.find(principal).map_or(false, |e| e.is_controller()) } + #[inline] + pub fn is_controller_or_custodian(&self, principal: &Principal) -> bool { + self.find(principal) + .map_or(false, |e| e.is_controller_or_custodian()) + } + #[inline] pub fn custodians(&self) -> impl Iterator { self.iter().filter(|e| e.is_custodian()) diff --git a/wallet/src/lib.did b/wallet/src/lib.did index 3d604aa..8e92b44 100644 --- a/wallet/src/lib.did +++ b/wallet/src/lib.did @@ -56,17 +56,17 @@ type AddressEntry = record { role: Role; }; -type ResultCreate = variant { +type WalletResultCreate = variant { Ok : record { canister_id: principal }; Err: text; }; -type ResultSend = variant { +type WalletResult = variant { Ok : null; Err : text; }; -type ResultCall = variant { +type WalletResultCall = variant { Ok : record { return: blob }; Err : text; }; @@ -91,22 +91,22 @@ service : { // Controller Management get_controllers: () -> (vec principal) query; add_controller: (principal) -> (); - remove_controller: (principal) -> (); + remove_controller: (principal) -> (WalletResult); // Custodian Management get_custodians: () -> (vec principal) query; authorize: (principal) -> (); - deauthorize: (principal) -> (); + deauthorize: (principal) -> (WalletResult); // Cycle Management wallet_balance: () -> (record { amount: nat64 }) query; - wallet_send: (record { canister: principal; amount: nat64 }) -> (ResultSend); + wallet_send: (record { canister: principal; amount: nat64 }) -> (WalletResult); wallet_receive: () -> (); // Endpoint for receiving cycles. // Managing canister - wallet_create_canister: (CreateCanisterArgs) -> (ResultCreate); + wallet_create_canister: (CreateCanisterArgs) -> (WalletResultCreate); - wallet_create_wallet: (CreateCanisterArgs) -> (ResultCreate); + wallet_create_wallet: (CreateCanisterArgs) -> (WalletResultCreate); wallet_store_wallet_wasm: (record { wasm_module: blob; @@ -118,12 +118,12 @@ service : { method_name: text; args: blob; cycles: nat64; - }) -> (ResultCall); + }) -> (WalletResultCall); // Address book add_address: (address: AddressEntry) -> (); list_addresses: () -> (vec AddressEntry) query; - remove_address: (address: principal) -> (); + remove_address: (address: principal) -> (WalletResult); // Events get_events: (opt record { from: opt nat32; to: opt nat32; }) -> (vec Event) query; diff --git a/wallet/src/lib.rs b/wallet/src/lib.rs index 7c23454..4377c26 100644 --- a/wallet/src/lib.rs +++ b/wallet/src/lib.rs @@ -87,7 +87,7 @@ fn post_upgrade() { /*************************************************************************************************** * Wallet Name **************************************************************************************************/ -#[query(guard = "is_custodian")] +#[query(guard = "is_custodian_or_controller")] fn name() -> Option { storage::get::().0.clone() } @@ -108,7 +108,7 @@ include!(concat!(env!("OUT_DIR"), "/http_request.rs")); **************************************************************************************************/ /// Get the controller of this canister. -#[query(guard = "is_custodian")] +#[query(guard = "is_custodian_or_controller")] fn get_controllers() -> Vec<&'static Principal> { storage::get_mut::() .controllers() @@ -125,14 +125,25 @@ fn add_controller(controller: Principal) { /// Remove a controller. This is equivalent to moving the role to a regular user. #[update(guard = "is_controller")] -fn remove_controller(controller: Principal) { - let book = storage::get_mut::(); +fn remove_controller(controller: Principal) -> Result<(), String> { + if !storage::get::().is_controller(&controller) { + return Err(format!( + "Cannot remove {} because it is not a controller.", + controller.to_text() + )); + } + if storage::get::().controllers().count() > 1 { + let book = storage::get_mut::(); - if let Some(mut entry) = book.take(&controller) { - entry.role = Role::Contact; - book.insert(entry); + if let Some(mut entry) = book.take(&controller) { + entry.role = Role::Contact; + book.insert(entry); + } + update_chart(); + Ok(()) + } else { + Err("The wallet must have at least one controller.".to_string()) } - update_chart(); } /*************************************************************************************************** @@ -140,7 +151,7 @@ fn remove_controller(controller: Principal) { **************************************************************************************************/ /// Get the custodians of this canister. -#[query(guard = "is_custodian")] +#[query(guard = "is_custodian_or_controller")] fn get_custodians() -> Vec<&'static Principal> { storage::get::() .custodians() @@ -157,13 +168,21 @@ fn authorize(custodian: Principal) { /// Deauthorize a custodian. #[update(guard = "is_controller")] -fn deauthorize(custodian: Principal) { - remove_address(custodian); - update_chart(); +fn deauthorize(custodian: Principal) -> Result<(), String> { + if storage::get::().is_custodian(&custodian) { + remove_address(custodian)?; + update_chart(); + Ok(()) + } else { + Err(format!( + "Cannot deauthorize {} as it is not a custodian.", + custodian.to_text() + )) + } } mod wallet { - use crate::{events, is_custodian}; + use crate::{events, is_custodian_or_controller}; use ic_cdk::export::candid::{CandidType, Nat}; use ic_cdk::export::Principal; use ic_cdk::{api, caller, id, storage}; @@ -185,7 +204,7 @@ mod wallet { } /// Return the cycle balance of this canister. - #[query(guard = "is_custodian", name = "wallet_balance")] + #[query(guard = "is_custodian_or_controller", name = "wallet_balance")] fn balance() -> BalanceResult { BalanceResult { amount: api::canister_balance() as u64, @@ -193,7 +212,7 @@ mod wallet { } /// Send cycles to another canister. - #[update(guard = "is_custodian", name = "wallet_send")] + #[update(guard = "is_custodian_or_controller", name = "wallet_send")] async fn send(args: SendCyclesArgs) -> Result<(), String> { match api::call::call_with_payment( args.canister.clone(), @@ -276,7 +295,7 @@ mod wallet { canister_id: Principal, } - #[update(guard = "is_custodian", name = "wallet_create_canister")] + #[update(guard = "is_custodian_or_controller", name = "wallet_create_canister")] async fn create_canister(args: CreateCanisterArgs) -> Result { let create_result = create_canister_call(args).await?; @@ -429,7 +448,7 @@ mod wallet { Ok(()) } - #[update(guard = "is_custodian", name = "wallet_create_wallet")] + #[update(guard = "is_custodian_or_controller", name = "wallet_create_wallet")] async fn create_wallet(args: CreateCanisterArgs) -> Result { let wallet_bytes = storage::get::(); let wasm_module = match &wallet_bytes.0 { @@ -503,7 +522,7 @@ mod wallet { } /// Forward a call to another canister. - #[update(guard = "is_custodian", name = "wallet_call")] + #[update(guard = "is_custodian_or_controller", name = "wallet_call")] async fn call(args: CallCanisterArgs) -> Result { if api::id() == caller() { return Err("Attempted to call forward on self. This is not allowed. Call this method via a different custodian.".to_string()); @@ -550,16 +569,23 @@ fn add_address(address: AddressEntry) { update_chart(); } -#[query(guard = "is_custodian")] +#[query(guard = "is_custodian_or_controller")] fn list_addresses() -> Vec<&'static AddressEntry> { storage::get::().iter().collect() } #[update(guard = "is_controller")] -fn remove_address(address: Principal) { - storage::get_mut::().remove(&address); - update_chart(); - record(EventKind::AddressRemoved { id: address }) +fn remove_address(address: Principal) -> Result<(), String> { + if storage::get::().is_controller(&address) + && storage::get::().controllers().count() == 1 + { + Err("The wallet must have at least one controller.".to_string()) + } else { + storage::get_mut::().remove(&address); + record(EventKind::AddressRemoved { id: address }); + update_chart(); + Ok(()) + } } /*************************************************************************************************** @@ -573,7 +599,7 @@ struct GetEventsArgs { } /// Return the recent events observed by this canister. -#[query(guard = "is_custodian")] +#[query(guard = "is_custodian_or_controller")] fn get_events(args: Option) -> &'static [Event] { if let Some(GetEventsArgs { from, to }) = args { events::get_events(from, to) @@ -597,7 +623,7 @@ struct GetChartArgs { precision: Option, } -#[query(guard = "is_custodian")] +#[query(guard = "is_custodian_or_controller")] fn get_chart(args: Option) -> Vec<(u64, u64)> { let chart = storage::get_mut::>(); @@ -648,11 +674,11 @@ fn is_controller() -> Result<(), String> { } /// Check if the caller is a custodian. -fn is_custodian() -> Result<(), String> { +fn is_custodian_or_controller() -> Result<(), String> { let caller = &caller(); - if storage::get::().is_custodian(caller) || &api::id() == caller { + if storage::get::().is_controller_or_custodian(caller) || &api::id() == caller { Ok(()) } else { - Err("Only a custodian can call this method.".to_string()) + Err("Only a controller or custodian can call this method.".to_string()) } }