diff --git a/pytest/common_lib/shared/__init__.py b/pytest/common_lib/shared/__init__.py index d46ea7d05..caea830d2 100644 --- a/pytest/common_lib/shared/__init__.py +++ b/pytest/common_lib/shared/__init__.py @@ -25,16 +25,83 @@ create_create_account_action, create_payment_action, create_full_access_key_action, + # create_mpc_function_call_access_key_action sign_transaction, serialize_transaction, + Action, + AccessKey, + AccessKeyPermission, + FunctionCallPermission, + PublicKey, + AddKey, ) + from key import Key dot_near = pathlib.Path.home() / ".near" SECRETS_JSON = "secrets.json" +def create_function_call_access_key_action( + pk: bytes, contract_id: str, method_names: list[str], allowance: int | None = None +) -> "Action": + permission = AccessKeyPermission() + permission.enum = "functionCall" + + fc_perm = FunctionCallPermission() + fc_perm.allowance = allowance + fc_perm.receiverId = contract_id + fc_perm.methodNames = method_names + permission.functionCall = fc_perm + + access_key = AccessKey() + access_key.nonce = 0 + access_key.permission = permission + + public_key = PublicKey() + public_key.keyType = 0 + public_key.data = pk + + add_key = AddKey() + add_key.accessKey = access_key + add_key.publicKey = public_key + + action = Action() + action.enum = "addKey" + action.addKey = add_key + + return action + + +def create_mpc_function_call_access_key_action( + pk: bytes, contract_id: str, allowance: int | None = None +) -> "Action": + """ + Create a restricted access key that only allows calling MPC-related contract methods. + """ + mpc_methods_used_by_node = [ + "respond", + "respond_ckd", + "vote_pk", + "start_keygen_instance", + "vote_reshared", + "start_reshare_instance", + "vote_abort_key_event_instance", + "verify_tee", + "submit_participant_info", + "vote_add_domains", + "vote_cancel_resharing", + ] + + return create_function_call_access_key_action( + pk=pk, + contract_id=contract_id, + method_names=mpc_methods_used_by_node, + allowance=allowance, + ) + + # Output is deserializable into the rust type near_sdk::SecretKey def serialize_key(key: bytes) -> str: key_bytes = bytes(key) @@ -47,19 +114,54 @@ def deserialize_key(account_id: str, key: List[int]) -> Key: return Key.from_keypair(account_id, signing_key) +def dump(obj, indent=0): + pad = " " * indent + if hasattr(obj, "__dict__"): + for k, v in vars(obj).items(): + print(f"{pad}{k}:") + dump(v, indent + 1) + elif isinstance(obj, (list, tuple)): + for i, v in enumerate(obj): + print(f"{pad}[{i}]:") + dump(v, indent + 1) + else: + print(f"{pad}{obj}") + + def sign_create_account_with_multiple_access_keys_tx( creator_key: Key, new_account_id, keys: List[Key], nonce, block_hash, + contract_id, + fullAccess: bool, + createNewAccount: bool, ) -> bytes: - create_account_action = create_create_account_action() - payment_action = create_payment_action(100 * NEAR_BASE) - access_key_actions = [ - create_full_access_key_action(key.decoded_pk()) for key in keys - ] - actions = [create_account_action, payment_action] + access_key_actions + actions = [] + + if createNewAccount: + # Only when creating a brand-new account + actions.append(create_create_account_action()) + actions.append(create_payment_action(100 * NEAR_BASE)) + + if fullAccess: + # Give full access to all keys + access_key_actions = [ + create_full_access_key_action(key.decoded_pk()) for key in keys + ] + else: + # Give restricted MPC-only access to all keys + access_key_actions = [ + create_mpc_function_call_access_key_action( + key.decoded_pk(), contract_id, allowance=100 * NEAR_BASE + ) + for key in keys + ] + print("access key actions:", access_key_actions) + dump(access_key_actions) + actions.extend(access_key_actions) + signed_tx = sign_transaction( new_account_id, nonce, @@ -69,6 +171,12 @@ def sign_create_account_with_multiple_access_keys_tx( creator_key.decoded_pk(), creator_key.decoded_sk(), ) + + print("signed tx: ", signed_tx) + dump(signed_tx) + + # for name, value in vars(signed_tx).items(): + # print(name, "=", value) return serialize_transaction(signed_tx) @@ -267,10 +375,12 @@ def start_cluster_with_mpc( ) (key, nonce) = cluster.contract_node.get_key_and_nonce() - txs = [] + create_txs = [] + access_txs = [] mpc_nodes = [] + pytest_keys_per_node = [] for near_node, candidate in zip(observers, candidates): - # add the nodes access key to the list + # add the nodes responder access key to the list nonce += 1 tx = sign_create_account_with_multiple_access_keys_tx( key, @@ -278,8 +388,11 @@ def start_cluster_with_mpc( candidate.responder_keys, nonce, cluster.contract_node.last_block_hash(), + cluster.mpc_contract_account(), + True, + True, ) - txs.append(tx) + create_txs.append(tx) candidate_account_id = candidate.signer_key.account_id pytest_signer_keys = [] for i in range(0, 5): @@ -295,14 +408,57 @@ def start_cluster_with_mpc( nonce += 1 # Observer nodes haven't started yet so we use cluster node to send txs + # add pytest_signer_keys that are used for voting, need to access tx = sign_create_account_with_multiple_access_keys_tx( key, candidate_account_id, - [candidate.signer_key] + pytest_signer_keys, + pytest_signer_keys, nonce, cluster.contract_node.last_block_hash(), + cluster.mpc_contract_account(), + True, + True, + ) + create_txs.append(tx) + pytest_keys_per_node.append(pytest_signer_keys) + + cluster.contract_node.send_await_check_txs_parallel( + "create account", create_txs, assert_txn_success + ) + + for near_node, candidate, pytest_signer_keys in zip( + observers, candidates, pytest_keys_per_node + ): + candidate_account_id = candidate.signer_key.account_id + + print( + "ALL ACCESS KEYS:", + cluster.contract_node.near_node.get_access_key_list(candidate_account_id), + ) + creator_key = pytest_signer_keys[0] + print("my signer key:", creator_key.pk) + nonce = cluster.contract_node.near_node.get_nonce_for_pk( + candidate_account_id, creator_key.pk + ) + print("candidate signer key:", candidate.signer_key) + assert nonce is not None + print("found nonce: ", nonce) + # nonce = 0 + # add node access key + tx = sign_create_account_with_multiple_access_keys_tx( + # key, + pytest_signer_keys[0], + candidate_account_id, + [candidate.signer_key], + nonce + 1, + cluster.contract_node.last_block_hash(), + cluster.mpc_contract_account(), + False, + False, ) - txs.append(tx) + access_txs.append(tx) + print("access key tx:") + dump(tx) mpc_node = MpcNode( near_node, @@ -315,7 +471,7 @@ def start_cluster_with_mpc( mpc_nodes.append(mpc_node) cluster.contract_node.send_await_check_txs_parallel( - "create account", txs, assert_txn_success + "access keys", access_txs, assert_txn_success ) # Deploy the mpc contract diff --git a/pytest/common_lib/shared/near_account.py b/pytest/common_lib/shared/near_account.py index 6bc044fe2..1de47ee96 100644 --- a/pytest/common_lib/shared/near_account.py +++ b/pytest/common_lib/shared/near_account.py @@ -65,8 +65,11 @@ def send_await_check_txs_parallel( txns: list[bytes], verification_callback: Callable[[dict[str, Any]], None], ): + print("sending") tx_hashes = self.send_txs_parallel_returning_hashes(txns, label) + print("awaiting") results = self.await_txs(tx_hashes) + print("received", results) verify_txs(results, verification_callback) def get_tx(self, tx_hash): diff --git a/pytest/common_lib/shared/transaction_status.py b/pytest/common_lib/shared/transaction_status.py index fc6b0791c..da59283b0 100644 --- a/pytest/common_lib/shared/transaction_status.py +++ b/pytest/common_lib/shared/transaction_status.py @@ -31,10 +31,17 @@ def verify_txs(results, verification_callback, verbose=False): total_tgas += gas_tx / TGAS total_receipts += n_rcpts_tx verification_callback(res) - if verbose: - print( - f"number of txs: {num_txs}\n max gas used (Tgas):{max_tgas_used}\n average receipts: {total_receipts / num_txs}\n average gas used (Tgas): {total_tgas / num_txs}\n" - ) + if True: # verbose: + if verbose: + if num_txs == 0: + print("number of txs: 0\n no gas or receipts to report") + else: + print( + f"number of txs: {num_txs}\n" + f" max gas used (Tgas):{max_tgas_used}\n" + f" average receipts: {total_receipts / num_txs}\n" + f" average gas used (Tgas): {total_tgas / num_txs}\n" + ) def assert_txn_success(result: dict[str, Any]):