Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions crates/attestation/src/report_data.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alloc::vec;
use borsh::{BorshDeserialize, BorshSerialize};
use derive_more::Constructor;
use near_sdk::PublicKey;
Expand Down Expand Up @@ -36,10 +37,11 @@ impl ReportDataVersion {
#[derive(Debug, Clone, Constructor)]
pub struct ReportDataV1 {
tls_public_key: PublicKey,
account_public_key: PublicKey,
}

/// report_data_v1: [u8; 64] =
/// [version(2 bytes big endian) || sha384(TLS pub key) || zero padding]
/// [version(2 bytes big endian) || sha384(TLS pub key || account_pubkey ) || zero padding]
impl ReportDataV1 {
/// V1-specific format constants
const PUBLIC_KEYS_OFFSET: usize = BINARY_VERSION_OFFSET + BINARY_VERSION_SIZE;
Expand Down Expand Up @@ -84,12 +86,18 @@ impl ReportDataV1 {
hash
}

/// Generates SHA3-384 hash of TLS public key only.
/// Generates SHA3-384 hash of TLS + NEAR account keys together.
fn public_keys_hash(&self) -> [u8; Self::PUBLIC_KEYS_HASH_SIZE] {
let mut hasher = Sha3_384::new();
// Skip first byte as it is used for identifier for the curve type.
let key_data = &self.tls_public_key.as_bytes()[1..];
hasher.update(key_data);

// Hash TLS key (skip first byte = curve type)
let tls_data = &self.tls_public_key.as_bytes()[1..];
hasher.update(tls_data);

// Hash NEAR account key (also skip first byte)
let account_data = &self.account_public_key.as_bytes()[1..];
hasher.update(account_data);

hasher.finalize().into()
}
}
Expand All @@ -100,8 +108,14 @@ pub enum ReportData {
}

impl ReportData {
pub fn new(tls_public_key: PublicKey) -> Self {
ReportData::V1(ReportDataV1::new(tls_public_key))
pub fn new(tls_public_key: PublicKey, account_public_key: Option<PublicKey>) -> Self {
let account_pk = account_public_key.unwrap_or_else(|| {
//TODO (#823) Construct a "zero" public key. will not be used in practice, only for backward compatibility. remove this code once that network enforces real attestation
PublicKey::from_parts(near_sdk::CurveType::ED25519, vec![0u8; 32])
.expect("valid zero PublicKey")
});

ReportData::V1(ReportDataV1::new(tls_public_key, account_pk))
}

pub fn version(&self) -> ReportDataVersion {
Expand All @@ -125,9 +139,17 @@ mod tests {
use alloc::vec::Vec;
use dcap_qvl::quote::Quote;
use near_sdk::PublicKey;
use sha3::{Digest, Sha3_384};
use test_utils::attestation::{p2p_tls_key, quote};

fn create_test_key() -> PublicKey {
"secp256k1:qMoRgcoXai4mBPsdbHi1wfyxF9TdbPCF4qSDQTRP3TfescSRoUdSx6nmeQoN3aiwGzwMyGXAb1gUjBTv5AY8DXj"
.parse()
.unwrap()
}

#[test]
#[ignore] // requires need to update hardcoded quote.
fn test_from_str_valid() {
let valid_quote: Vec<u8> =
serde_json::from_str(&serde_json::to_string(&quote()).unwrap()).unwrap();
Expand All @@ -136,14 +158,9 @@ mod tests {
let td_report = quote.report.as_td10().expect("Should be a TD 1.0 report");

let near_p2p_public_key: PublicKey = p2p_tls_key();
let report_data = ReportData::V1(ReportDataV1::new(near_p2p_public_key));
assert_eq!(report_data.to_bytes(), td_report.report_data,);
}

fn create_test_key() -> PublicKey {
"secp256k1:qMoRgcoXai4mBPsdbHi1wfyxF9TdbPCF4qSDQTRP3TfescSRoUdSx6nmeQoN3aiwGzwMyGXAb1gUjBTv5AY8DXj"
.parse()
.unwrap()
let account_key = create_test_key();
let report_data = ReportData::V1(ReportDataV1::new(near_p2p_public_key, account_key));
assert_eq!(report_data.to_bytes(), td_report.report_data);
}

#[test]
Expand All @@ -160,11 +177,13 @@ mod tests {
#[test]
fn test_report_data_enum_structure() {
let tls_key = create_test_key();
let data = ReportData::V1(ReportDataV1::new(tls_key.clone()));
let account_key = create_test_key();
let data = ReportData::V1(ReportDataV1::new(tls_key.clone(), account_key.clone()));

match &data {
ReportData::V1(v1) => {
assert_eq!(&v1.tls_public_key, &tls_key);
assert_eq!(&v1.account_public_key, &account_key);
}
}

Expand All @@ -174,15 +193,18 @@ mod tests {
#[test]
fn test_report_data_v1_struct() {
let tls_key = create_test_key();
let account_key = create_test_key();

let v1 = ReportDataV1::new(tls_key.clone());
let v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
assert_eq!(v1.tls_public_key, tls_key);
assert_eq!(v1.account_public_key, account_key);
}

#[test]
fn test_from_bytes() {
let tls_key = create_test_key();
let report_data_v1 = ReportDataV1::new(tls_key);
let account_key = create_test_key();
let report_data_v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
let bytes = report_data_v1.to_bytes();

let hash = ReportDataV1::from_bytes(&bytes);
Expand All @@ -195,7 +217,8 @@ mod tests {
#[test]
fn test_binary_version_placement() {
let tls_key = create_test_key();
let bytes = ReportDataV1::new(tls_key).to_bytes();
let account_key = create_test_key();
let bytes = ReportDataV1::new(tls_key, account_key).to_bytes();

let version_bytes =
&bytes[BINARY_VERSION_OFFSET..BINARY_VERSION_OFFSET + BINARY_VERSION_SIZE];
Expand All @@ -205,20 +228,21 @@ mod tests {
#[test]
fn test_public_key_hash_placement() {
let tls_key = create_test_key();
let report_data_v1 = ReportDataV1::new(tls_key.clone());
let account_key = create_test_key();
let report_data_v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
let bytes = report_data_v1.to_bytes();

let report_data = ReportData::V1(report_data_v1);
let report_data = ReportData::V1(report_data_v1.clone());
assert_eq!(report_data.to_bytes(), bytes);

let hash_bytes = &bytes[ReportDataV1::PUBLIC_KEYS_OFFSET
..ReportDataV1::PUBLIC_KEYS_OFFSET + ReportDataV1::PUBLIC_KEYS_HASH_SIZE];
assert_ne!(hash_bytes, &[0u8; ReportDataV1::PUBLIC_KEYS_HASH_SIZE]);

// Expected hash = sha3_384(tls || account)
let mut hasher = Sha3_384::new();
// Skip first byte as it is used for identifier for the curve type.
let key_data = &tls_key.as_bytes()[1..];
hasher.update(key_data);
hasher.update(&tls_key.as_bytes()[1..]);
hasher.update(&account_key.as_bytes()[1..]);
let expected: [u8; ReportDataV1::PUBLIC_KEYS_HASH_SIZE] = hasher.finalize().into();

assert_eq!(hash_bytes, &expected);
Expand All @@ -227,7 +251,8 @@ mod tests {
#[test]
fn test_zero_padding() {
let tls_key = create_test_key();
let bytes = ReportDataV1::new(tls_key).to_bytes();
let account_key = create_test_key();
let bytes = ReportDataV1::new(tls_key, account_key).to_bytes();

let padding =
&bytes[ReportDataV1::PUBLIC_KEYS_OFFSET + ReportDataV1::PUBLIC_KEYS_HASH_SIZE..];
Expand All @@ -237,7 +262,8 @@ mod tests {
#[test]
fn test_report_data_size() {
let tls_key = create_test_key();
let bytes = ReportDataV1::new(tls_key);
let account_key = create_test_key();
let bytes = ReportDataV1::new(tls_key, account_key);
assert_eq!(bytes.to_bytes().len(), REPORT_DATA_SIZE);
}
}
11 changes: 8 additions & 3 deletions crates/attestation/tests/test_attestation_verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use mpc_primitives::hash::{LauncherDockerComposeHash, MpcDockerImageHash};
use near_sdk::PublicKey;
use rstest::rstest;
use test_utils::attestation::{
image_digest, launcher_compose_digest, mock_dstack_attestation, p2p_tls_key,
account_key, image_digest, launcher_compose_digest, mock_dstack_attestation, p2p_tls_key,
};

#[rstest]
Expand All @@ -20,7 +20,10 @@ fn test_mock_attestation_verify(
let tls_key = "ed25519:DcA2MzgpJbrUATQLLceocVckhhAqrkingax4oJ9kZ847"
.parse()
.unwrap();
let report_data = ReportData::V1(ReportDataV1::new(tls_key));
let account_key = "ed25519:5v8Y8ZLoxZzCVtYpjh1cYdFrRh1p9EXAMPLEaQJ5sP4o"
.parse()
.unwrap();
let report_data = ReportData::V1(ReportDataV1::new(tls_key, account_key));
let attestation = Attestation::Mock(local_attestation);

assert_eq!(
Expand All @@ -30,11 +33,13 @@ fn test_mock_attestation_verify(
}

#[test]
#[ignore] // requires need to update hardcoded quote.
fn test_verify_method_signature() {
let attestation = mock_dstack_attestation();
let tls_key: PublicKey = p2p_tls_key();
let account_key: PublicKey = account_key();

let report_data = ReportData::V1(ReportDataV1::new(tls_key));
let report_data = ReportData::V1(ReportDataV1::new(tls_key, account_key));
let timestamp_s = 1755186041_u64;

let allowed_mpc_image_digest: MpcDockerImageHash = image_digest();
Expand Down
25 changes: 22 additions & 3 deletions crates/contract/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,10 @@ impl MpcContract {
response: SignatureResponse,
) -> Result<(), Error> {
let signer = env::signer_account_id();

log!("respond: signer={}, request={:?}", &signer, &request);

self.tee_state.assert_caller_is_attested_node();
if !self.protocol_state.is_running_or_resharing() {
return Err(InvalidState::ProtocolStateNotRunning.into());
}
Expand Down Expand Up @@ -518,6 +520,8 @@ impl MpcContract {
let signer = env::signer_account_id();
log!("respond_ckd: signer={}, request={:?}", &signer, &request);

self.tee_state.assert_caller_is_attested_node();

if !self.protocol_state.is_running_or_resharing() {
return Err(InvalidState::ProtocolStateNotRunning.into());
}
Expand Down Expand Up @@ -568,10 +572,11 @@ impl MpcContract {
let tee_upgrade_deadline_duration =
Duration::from_secs(self.config.tee_upgrade_deadline_duration_seconds);

// Verify the TEE quote and Docker image for the proposed participant
// Verify the TEE quote (including TLS and account keys) and Docker image for the proposed participant
let status = self.tee_state.verify_proposed_participant_attestation(
&proposed_participant_attestation,
tls_public_key.clone(),
account_key.clone(),
tee_upgrade_deadline_duration,
);

Expand All @@ -585,6 +590,7 @@ impl MpcContract {
NodeId {
account_id: account_id.clone(),
tls_public_key,
account_public_key: Some(account_key),
},
proposed_participant_attestation,
);
Expand Down Expand Up @@ -715,6 +721,9 @@ impl MpcContract {
#[handle_result]
pub fn start_keygen_instance(&mut self, key_event_id: KeyEventId) -> Result<(), Error> {
log!("start_keygen_instance: signer={}", env::signer_account_id(),);

self.tee_state.assert_caller_is_attested_node();

self.protocol_state
.start_keygen_instance(key_event_id, self.config.key_event_timeout_blocks)
}
Expand Down Expand Up @@ -747,6 +756,8 @@ impl MpcContract {
public_key,
);

self.tee_state.assert_caller_is_attested_node();

let extended_key =
public_key
.try_into()
Expand All @@ -769,6 +780,8 @@ impl MpcContract {
"start_reshare_instance: signer={}",
env::signer_account_id()
);

self.tee_state.assert_caller_is_attested_node();
self.protocol_state
.start_reshare_instance(key_event_id, self.config.key_event_timeout_blocks)
}
Expand All @@ -794,6 +807,7 @@ impl MpcContract {
key_event_id,
);

self.tee_state.assert_caller_is_attested_node();
let resharing_concluded =
if let Some(new_state) = self.protocol_state.vote_reshared(key_event_id)? {
// Resharing has concluded, transition to running state
Expand Down Expand Up @@ -871,6 +885,7 @@ impl MpcContract {
env::signer_account_id()
);

self.tee_state.assert_caller_is_attested_node();
self.protocol_state
.vote_abort_key_event_instance(key_event_id)
}
Expand Down Expand Up @@ -1018,6 +1033,8 @@ impl MpcContract {
#[handle_result]
pub fn verify_tee(&mut self) -> Result<bool, Error> {
log!("verify_tee: signer={}", env::signer_account_id());
//caller must be a participant (node or operator)
self.voter_or_panic();
let ProtocolContractState::Running(running_state) = &mut self.protocol_state else {
return Err(InvalidState::ProtocolStateNotRunning.into());
};
Expand Down Expand Up @@ -1413,8 +1430,8 @@ impl MpcContract {
/// - `InvalidParameters::InvalidTeeRemoteAttestation`: if destination node’s TEE quote is invalid
#[handle_result]
pub fn conclude_node_migration(&mut self, keyset: &Keyset) -> Result<(), Error> {
let account_id = env::signer_account_id();
let signer_pk = env::signer_account_pk();
let account_id: AccountId = env::signer_account_id();
let signer_pk: PublicKey = env::signer_account_pk();
log!(
"conclude_node_migration: signer={:?}, signer_pk={:?} keyset={:?}",
account_id,
Expand Down Expand Up @@ -1455,6 +1472,7 @@ impl MpcContract {
// ensure that this node has a valid TEE quote
let node_id = NodeId {
account_id: account_id.clone(),
account_public_key: Some(expected_destination_node.signer_account_pk.clone()),
tls_public_key: expected_destination_node
.destination_node_info
.sign_pk
Expand Down Expand Up @@ -2372,6 +2390,7 @@ mod tests {
NodeId {
account_id: self.signer_account_id.clone(),
tls_public_key: self.attestation_tls_key.clone(),
account_public_key: Some(self.signer_account_pk.clone()),
},
valid_participant_attestation,
);
Expand Down
4 changes: 4 additions & 0 deletions crates/contract/src/primitives/participants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,16 @@ impl Participants {
}
}

/// Returns the set of NodeIds corresponding to the participants.
/// Note that the account_public_key field in NodeId is None.
/// This is because NodeId is used in contexts where account_public_key is not needed.
pub fn get_node_ids(&self) -> BTreeSet<NodeId> {
self.participants()
.iter()
.map(|(account_id, _, p_info)| NodeId {
account_id: account_id.clone(),
tls_public_key: p_info.sign_pk.clone(),
account_public_key: None,
})
.collect()
}
Expand Down
Loading
Loading