Skip to content

Commit b0b4617

Browse files
pbezabarakeinav1
authored andcommitted
fix: address post-merge code review comments for PR #1183 (#1189)
1 parent 7f52f01 commit b0b4617

File tree

14 files changed

+205
-75
lines changed

14 files changed

+205
-75
lines changed

crates/attestation/src/report_data.rs

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use derive_more::Constructor;
33
use near_sdk::PublicKey;
44
use serde::{Deserialize, Serialize};
55
use sha3::{Digest, Sha3_384};
6+
use alloc::vec;
67

78
/// Number of bytes for the report data.
89
const REPORT_DATA_SIZE: usize = 64;
@@ -36,10 +37,11 @@ impl ReportDataVersion {
3637
#[derive(Debug, Clone, Constructor)]
3738
pub struct ReportDataV1 {
3839
tls_public_key: PublicKey,
40+
account_public_key: PublicKey,
3941
}
4042

4143
/// report_data_v1: [u8; 64] =
42-
/// [version(2 bytes big endian) || sha384(TLS pub key) || zero padding]
44+
/// [version(2 bytes big endian) || sha384(TLS pub key || account_pubkey ) || zero padding]
4345
impl ReportDataV1 {
4446
/// V1-specific format constants
4547
const PUBLIC_KEYS_OFFSET: usize = BINARY_VERSION_OFFSET + BINARY_VERSION_SIZE;
@@ -84,12 +86,18 @@ impl ReportDataV1 {
8486
hash
8587
}
8688

87-
/// Generates SHA3-384 hash of TLS public key only.
89+
/// Generates SHA3-384 hash of TLS + NEAR account keys together.
8890
fn public_keys_hash(&self) -> [u8; Self::PUBLIC_KEYS_HASH_SIZE] {
8991
let mut hasher = Sha3_384::new();
90-
// Skip first byte as it is used for identifier for the curve type.
91-
let key_data = &self.tls_public_key.as_bytes()[1..];
92-
hasher.update(key_data);
92+
93+
// Hash TLS key (skip first byte = curve type)
94+
let tls_data = &self.tls_public_key.as_bytes()[1..];
95+
hasher.update(tls_data);
96+
97+
// Hash NEAR account key (also skip first byte)
98+
let account_data = &self.account_public_key.as_bytes()[1..];
99+
hasher.update(account_data);
100+
93101
hasher.finalize().into()
94102
}
95103
}
@@ -100,8 +108,13 @@ pub enum ReportData {
100108
}
101109

102110
impl ReportData {
103-
pub fn new(tls_public_key: PublicKey) -> Self {
104-
ReportData::V1(ReportDataV1::new(tls_public_key))
111+
pub fn new(tls_public_key: PublicKey, account_public_key: Option<PublicKey>) -> Self {
112+
let account_pk = account_public_key.unwrap_or_else(|| {
113+
// Construct a "zero" public key. will not be used in practice, only for backward compatibility. remove this code once that network enforces real attestation
114+
PublicKey::from_parts(near_sdk::CurveType::ED25519, vec![0u8; 32]).expect("valid zero PublicKey")
115+
});
116+
117+
ReportData::V1(ReportDataV1::new(tls_public_key, account_pk))
105118
}
106119

107120
pub fn version(&self) -> ReportDataVersion {
@@ -125,8 +138,15 @@ mod tests {
125138
use alloc::vec::Vec;
126139
use dcap_qvl::quote::Quote;
127140
use near_sdk::PublicKey;
141+
use sha3::{Digest, Sha3_384};
128142
use test_utils::attestation::{p2p_tls_key, quote};
129143

144+
fn create_test_key() -> PublicKey {
145+
"secp256k1:qMoRgcoXai4mBPsdbHi1wfyxF9TdbPCF4qSDQTRP3TfescSRoUdSx6nmeQoN3aiwGzwMyGXAb1gUjBTv5AY8DXj"
146+
.parse()
147+
.unwrap()
148+
}
149+
130150
#[test]
131151
fn test_from_str_valid() {
132152
let valid_quote: Vec<u8> =
@@ -136,14 +156,9 @@ mod tests {
136156
let td_report = quote.report.as_td10().expect("Should be a TD 1.0 report");
137157

138158
let near_p2p_public_key: PublicKey = p2p_tls_key();
139-
let report_data = ReportData::V1(ReportDataV1::new(near_p2p_public_key));
140-
assert_eq!(report_data.to_bytes(), td_report.report_data,);
141-
}
142-
143-
fn create_test_key() -> PublicKey {
144-
"secp256k1:qMoRgcoXai4mBPsdbHi1wfyxF9TdbPCF4qSDQTRP3TfescSRoUdSx6nmeQoN3aiwGzwMyGXAb1gUjBTv5AY8DXj"
145-
.parse()
146-
.unwrap()
159+
let account_key = create_test_key();
160+
let report_data = ReportData::V1(ReportDataV1::new(near_p2p_public_key, account_key));
161+
assert_eq!(report_data.to_bytes(), td_report.report_data);
147162
}
148163

149164
#[test]
@@ -160,11 +175,13 @@ mod tests {
160175
#[test]
161176
fn test_report_data_enum_structure() {
162177
let tls_key = create_test_key();
163-
let data = ReportData::V1(ReportDataV1::new(tls_key.clone()));
178+
let account_key = create_test_key();
179+
let data = ReportData::V1(ReportDataV1::new(tls_key.clone(), account_key.clone()));
164180

165181
match &data {
166182
ReportData::V1(v1) => {
167183
assert_eq!(&v1.tls_public_key, &tls_key);
184+
assert_eq!(&v1.account_public_key, &account_key);
168185
}
169186
}
170187

@@ -174,15 +191,18 @@ mod tests {
174191
#[test]
175192
fn test_report_data_v1_struct() {
176193
let tls_key = create_test_key();
194+
let account_key = create_test_key();
177195

178-
let v1 = ReportDataV1::new(tls_key.clone());
196+
let v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
179197
assert_eq!(v1.tls_public_key, tls_key);
198+
assert_eq!(v1.account_public_key, account_key);
180199
}
181200

182201
#[test]
183202
fn test_from_bytes() {
184203
let tls_key = create_test_key();
185-
let report_data_v1 = ReportDataV1::new(tls_key);
204+
let account_key = create_test_key();
205+
let report_data_v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
186206
let bytes = report_data_v1.to_bytes();
187207

188208
let hash = ReportDataV1::from_bytes(&bytes);
@@ -195,7 +215,8 @@ mod tests {
195215
#[test]
196216
fn test_binary_version_placement() {
197217
let tls_key = create_test_key();
198-
let bytes = ReportDataV1::new(tls_key).to_bytes();
218+
let account_key = create_test_key();
219+
let bytes = ReportDataV1::new(tls_key, account_key).to_bytes();
199220

200221
let version_bytes =
201222
&bytes[BINARY_VERSION_OFFSET..BINARY_VERSION_OFFSET + BINARY_VERSION_SIZE];
@@ -205,20 +226,21 @@ mod tests {
205226
#[test]
206227
fn test_public_key_hash_placement() {
207228
let tls_key = create_test_key();
208-
let report_data_v1 = ReportDataV1::new(tls_key.clone());
229+
let account_key = create_test_key();
230+
let report_data_v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
209231
let bytes = report_data_v1.to_bytes();
210232

211-
let report_data = ReportData::V1(report_data_v1);
233+
let report_data = ReportData::V1(report_data_v1.clone());
212234
assert_eq!(report_data.to_bytes(), bytes);
213235

214236
let hash_bytes = &bytes[ReportDataV1::PUBLIC_KEYS_OFFSET
215237
..ReportDataV1::PUBLIC_KEYS_OFFSET + ReportDataV1::PUBLIC_KEYS_HASH_SIZE];
216238
assert_ne!(hash_bytes, &[0u8; ReportDataV1::PUBLIC_KEYS_HASH_SIZE]);
217239

240+
// Expected hash = sha3_384(tls || account)
218241
let mut hasher = Sha3_384::new();
219-
// Skip first byte as it is used for identifier for the curve type.
220-
let key_data = &tls_key.as_bytes()[1..];
221-
hasher.update(key_data);
242+
hasher.update(&tls_key.as_bytes()[1..]);
243+
hasher.update(&account_key.as_bytes()[1..]);
222244
let expected: [u8; ReportDataV1::PUBLIC_KEYS_HASH_SIZE] = hasher.finalize().into();
223245

224246
assert_eq!(hash_bytes, &expected);
@@ -227,7 +249,8 @@ mod tests {
227249
#[test]
228250
fn test_zero_padding() {
229251
let tls_key = create_test_key();
230-
let bytes = ReportDataV1::new(tls_key).to_bytes();
252+
let account_key = create_test_key();
253+
let bytes = ReportDataV1::new(tls_key, account_key).to_bytes();
231254

232255
let padding =
233256
&bytes[ReportDataV1::PUBLIC_KEYS_OFFSET + ReportDataV1::PUBLIC_KEYS_HASH_SIZE..];
@@ -237,7 +260,8 @@ mod tests {
237260
#[test]
238261
fn test_report_data_size() {
239262
let tls_key = create_test_key();
240-
let bytes = ReportDataV1::new(tls_key);
263+
let account_key = create_test_key();
264+
let bytes = ReportDataV1::new(tls_key, account_key);
241265
assert_eq!(bytes.to_bytes().len(), REPORT_DATA_SIZE);
242266
}
243267
}

crates/attestation/tests/test_attestation_verification.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use mpc_primitives::hash::{LauncherDockerComposeHash, MpcDockerImageHash};
66
use near_sdk::PublicKey;
77
use rstest::rstest;
88
use test_utils::attestation::{
9-
image_digest, launcher_compose_digest, mock_dstack_attestation, p2p_tls_key,
9+
image_digest, launcher_compose_digest, mock_dstack_attestation, p2p_tls_key, account_key,
1010
};
1111

1212
#[rstest]
@@ -20,7 +20,10 @@ fn test_mock_attestation_verify(
2020
let tls_key = "ed25519:DcA2MzgpJbrUATQLLceocVckhhAqrkingax4oJ9kZ847"
2121
.parse()
2222
.unwrap();
23-
let report_data = ReportData::V1(ReportDataV1::new(tls_key));
23+
let account_key = "ed25519:5v8Y8ZLoxZzCVtYpjh1cYdFrRh1p9EXAMPLEaQJ5sP4o"
24+
.parse()
25+
.unwrap();
26+
let report_data = ReportData::V1(ReportDataV1::new(tls_key,account_key));
2427
let attestation = Attestation::Mock(local_attestation);
2528

2629
assert_eq!(
@@ -33,8 +36,9 @@ fn test_mock_attestation_verify(
3336
fn test_verify_method_signature() {
3437
let attestation = mock_dstack_attestation();
3538
let tls_key: PublicKey = p2p_tls_key();
39+
let account_key: PublicKey = account_key();
3640

37-
let report_data = ReportData::V1(ReportDataV1::new(tls_key));
41+
let report_data = ReportData::V1(ReportDataV1::new(tls_key,account_key));
3842
let timestamp_s = 1755186041_u64;
3943

4044
let allowed_mpc_image_digest: MpcDockerImageHash = image_digest();

crates/contract/src/lib.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ impl MpcContract {
440440
let signer = env::signer_account_id();
441441
log!("respond: signer={}, request={:?}", &signer, &request);
442442

443+
self.tee_state.assert_caller_is_attested_node();
443444
if !self.protocol_state.is_running_or_resharing() {
444445
return Err(InvalidState::ProtocolStateNotRunning.into());
445446
}
@@ -519,6 +520,8 @@ impl MpcContract {
519520
pub fn respond_ckd(&mut self, request: CKDRequest, response: CKDResponse) -> Result<(), Error> {
520521
let signer = env::signer_account_id();
521522
log!("respond_ckd: signer={}, request={:?}", &signer, &request);
523+
524+
self.tee_state.assert_caller_is_attested_node();
522525

523526
if !self.protocol_state.is_running_or_resharing() {
524527
return Err(InvalidState::ProtocolStateNotRunning.into());
@@ -570,10 +573,11 @@ impl MpcContract {
570573
let tee_upgrade_deadline_duration =
571574
Duration::from_secs(self.config.tee_upgrade_deadline_duration_seconds);
572575

573-
// Verify the TEE quote and Docker image for the proposed participant
576+
// Verify the TEE quote (including TLS and account keys) and Docker image for the proposed participant
574577
let status = self.tee_state.verify_proposed_participant_attestation(
575578
&proposed_participant_attestation,
576579
tls_public_key.clone(),
580+
account_key.clone(),
577581
tee_upgrade_deadline_duration,
578582
);
579583

@@ -587,6 +591,7 @@ impl MpcContract {
587591
NodeId {
588592
account_id: account_id.clone(),
589593
tls_public_key,
594+
account_public_key: Some(account_key),
590595
},
591596
proposed_participant_attestation,
592597
);
@@ -701,6 +706,9 @@ impl MpcContract {
701706
#[handle_result]
702707
pub fn start_keygen_instance(&mut self, key_event_id: KeyEventId) -> Result<(), Error> {
703708
log!("start_keygen_instance: signer={}", env::signer_account_id(),);
709+
710+
self.tee_state.assert_caller_is_attested_node();
711+
704712
self.protocol_state
705713
.start_keygen_instance(key_event_id, self.config.key_event_timeout_blocks)
706714
}
@@ -732,6 +740,8 @@ impl MpcContract {
732740
key_event_id,
733741
public_key,
734742
);
743+
744+
self.tee_state.assert_caller_is_attested_node();
735745

736746
let extended_key =
737747
public_key
@@ -755,6 +765,8 @@ impl MpcContract {
755765
"start_reshare_instance: signer={}",
756766
env::signer_account_id()
757767
);
768+
769+
self.tee_state.assert_caller_is_attested_node();
758770
self.protocol_state
759771
.start_reshare_instance(key_event_id, self.config.key_event_timeout_blocks)
760772
}
@@ -779,7 +791,8 @@ impl MpcContract {
779791
env::signer_account_id(),
780792
key_event_id,
781793
);
782-
794+
795+
self.tee_state.assert_caller_is_attested_node();
783796
let resharing_concluded =
784797
if let Some(new_state) = self.protocol_state.vote_reshared(key_event_id)? {
785798
// Resharing has concluded, transition to running state
@@ -856,7 +869,8 @@ impl MpcContract {
856869
"vote_abort_key_event_instance: signer={}",
857870
env::signer_account_id()
858871
);
859-
872+
873+
self.tee_state.assert_caller_is_attested_node();
860874
self.protocol_state
861875
.vote_abort_key_event_instance(key_event_id)
862876
}
@@ -1004,6 +1018,8 @@ impl MpcContract {
10041018
#[handle_result]
10051019
pub fn verify_tee(&mut self) -> Result<bool, Error> {
10061020
log!("verify_tee: signer={}", env::signer_account_id());
1021+
//caller must be a participant (node or operator)
1022+
let voter = self.voter_or_panic();
10071023
let ProtocolContractState::Running(running_state) = &mut self.protocol_state else {
10081024
return Err(InvalidState::ProtocolStateNotRunning.into());
10091025
};
@@ -1391,8 +1407,8 @@ impl MpcContract {
13911407
/// - `InvalidParameters::InvalidTeeRemoteAttestation`: if destination node’s TEE quote is invalid
13921408
#[handle_result]
13931409
pub fn conclude_node_migration(&mut self, keyset: &Keyset) -> Result<(), Error> {
1394-
let account_id = env::signer_account_id();
1395-
let signer_pk = env::signer_account_pk();
1410+
let account_id: AccountId = env::signer_account_id();
1411+
let signer_pk: PublicKey = env::signer_account_pk();
13961412
log!(
13971413
"conclude_node_migration: signer={:?}, signer_pk={:?} keyset={:?}",
13981414
account_id,
@@ -1433,6 +1449,7 @@ impl MpcContract {
14331449
// ensure that this node has a valid TEE quote
14341450
let node_id = NodeId {
14351451
account_id: account_id.clone(),
1452+
account_public_key: Some(expected_destination_node.signer_account_pk.clone()),
14361453
tls_public_key: expected_destination_node
14371454
.destination_node_info
14381455
.sign_pk
@@ -2377,6 +2394,7 @@ mod tests {
23772394
NodeId {
23782395
account_id: self.signer_account_id.clone(),
23792396
tls_public_key: self.attestation_tls_key.clone(),
2397+
account_public_key: Some(self.signer_account_pk.clone()),
23802398
},
23812399
valid_participant_attestation,
23822400
);

crates/contract/src/primitives/participants.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl Participants {
210210
.map(|(account_id, _, p_info)| NodeId {
211211
account_id: account_id.clone(),
212212
tls_public_key: p_info.sign_pk.clone(),
213+
account_public_key: None,
213214
})
214215
.collect()
215216
}

crates/contract/src/state/key_event.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl KeyEvent {
105105
.as_mut()
106106
.unwrap()
107107
.vote_success(candidate, public_key)?
108-
{
108+
{
109109
VoteSuccessResult::Voted(count) => {
110110
if count == self.parameters.participants().len() {
111111
Ok(true)

0 commit comments

Comments
 (0)