Skip to content

Commit 65832f2

Browse files
committed
feat: add public key to report data and enforcement on the contract
1 parent e224461 commit 65832f2

File tree

13 files changed

+138
-41
lines changed

13 files changed

+138
-41
lines changed

crates/attestation/src/report_data.rs

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use derive_more::Constructor;
33
use near_sdk::PublicKey;
44
use serde::{Deserialize, Serialize};
55
use sha3::{Digest, Sha3_384};
6+
use alloc::vec;
67

78
/// Number of bytes for the report data.
89
const REPORT_DATA_SIZE: usize = 64;
@@ -36,10 +37,11 @@ impl ReportDataVersion {
3637
#[derive(Debug, Clone, Constructor)]
3738
pub struct ReportDataV1 {
3839
tls_public_key: PublicKey,
40+
account_public_key: PublicKey,
3941
}
4042

4143
/// report_data_v1: [u8; 64] =
42-
/// [version(2 bytes big endian) || sha384(TLS pub key) || zero padding]
44+
/// [version(2 bytes big endian) || sha384(TLS pub key || account_pubkey ) || zero padding]
4345
impl ReportDataV1 {
4446
/// V1-specific format constants
4547
const PUBLIC_KEYS_OFFSET: usize = BINARY_VERSION_OFFSET + BINARY_VERSION_SIZE;
@@ -84,12 +86,18 @@ impl ReportDataV1 {
8486
hash
8587
}
8688

87-
/// Generates SHA3-384 hash of TLS public key only.
89+
/// Generates SHA3-384 hash of TLS + NEAR account keys together.
8890
fn public_keys_hash(&self) -> [u8; Self::PUBLIC_KEYS_HASH_SIZE] {
8991
let mut hasher = Sha3_384::new();
90-
// Skip first byte as it is used for identifier for the curve type.
91-
let key_data = &self.tls_public_key.as_bytes()[1..];
92-
hasher.update(key_data);
92+
93+
// Hash TLS key (skip first byte = curve type)
94+
let tls_data = &self.tls_public_key.as_bytes()[1..];
95+
hasher.update(tls_data);
96+
97+
// Hash NEAR account key (also skip first byte)
98+
let account_data = &self.account_public_key.as_bytes()[1..];
99+
hasher.update(account_data);
100+
93101
hasher.finalize().into()
94102
}
95103
}
@@ -100,8 +108,13 @@ pub enum ReportData {
100108
}
101109

102110
impl ReportData {
103-
pub fn new(tls_public_key: PublicKey) -> Self {
104-
ReportData::V1(ReportDataV1::new(tls_public_key))
111+
pub fn new(tls_public_key: PublicKey, account_public_key: Option<PublicKey>) -> Self {
112+
let account_pk = account_public_key.unwrap_or_else(|| {
113+
// Construct a "zero" public key. will not be used in practice, only for backward compatibility. remove this code once that network enforces real attestation
114+
PublicKey::from_parts(near_sdk::CurveType::ED25519, vec![0u8; 32]).expect("valid zero PublicKey")
115+
});
116+
117+
ReportData::V1(ReportDataV1::new(tls_public_key, account_pk))
105118
}
106119

107120
pub fn version(&self) -> ReportDataVersion {
@@ -125,8 +138,15 @@ mod tests {
125138
use alloc::vec::Vec;
126139
use dcap_qvl::quote::Quote;
127140
use near_sdk::PublicKey;
141+
use sha3::{Digest, Sha3_384};
128142
use test_utils::attestation::{p2p_tls_key, quote};
129143

144+
fn create_test_key() -> PublicKey {
145+
"secp256k1:qMoRgcoXai4mBPsdbHi1wfyxF9TdbPCF4qSDQTRP3TfescSRoUdSx6nmeQoN3aiwGzwMyGXAb1gUjBTv5AY8DXj"
146+
.parse()
147+
.unwrap()
148+
}
149+
130150
#[test]
131151
fn test_from_str_valid() {
132152
let valid_quote: Vec<u8> =
@@ -136,14 +156,9 @@ mod tests {
136156
let td_report = quote.report.as_td10().expect("Should be a TD 1.0 report");
137157

138158
let near_p2p_public_key: PublicKey = p2p_tls_key();
139-
let report_data = ReportData::V1(ReportDataV1::new(near_p2p_public_key));
140-
assert_eq!(report_data.to_bytes(), td_report.report_data,);
141-
}
142-
143-
fn create_test_key() -> PublicKey {
144-
"secp256k1:qMoRgcoXai4mBPsdbHi1wfyxF9TdbPCF4qSDQTRP3TfescSRoUdSx6nmeQoN3aiwGzwMyGXAb1gUjBTv5AY8DXj"
145-
.parse()
146-
.unwrap()
159+
let account_key = create_test_key();
160+
let report_data = ReportData::V1(ReportDataV1::new(near_p2p_public_key, account_key));
161+
assert_eq!(report_data.to_bytes(), td_report.report_data);
147162
}
148163

149164
#[test]
@@ -160,11 +175,13 @@ mod tests {
160175
#[test]
161176
fn test_report_data_enum_structure() {
162177
let tls_key = create_test_key();
163-
let data = ReportData::V1(ReportDataV1::new(tls_key.clone()));
178+
let account_key = create_test_key();
179+
let data = ReportData::V1(ReportDataV1::new(tls_key.clone(), account_key.clone()));
164180

165181
match &data {
166182
ReportData::V1(v1) => {
167183
assert_eq!(&v1.tls_public_key, &tls_key);
184+
assert_eq!(&v1.account_public_key, &account_key);
168185
}
169186
}
170187

@@ -174,15 +191,18 @@ mod tests {
174191
#[test]
175192
fn test_report_data_v1_struct() {
176193
let tls_key = create_test_key();
194+
let account_key = create_test_key();
177195

178-
let v1 = ReportDataV1::new(tls_key.clone());
196+
let v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
179197
assert_eq!(v1.tls_public_key, tls_key);
198+
assert_eq!(v1.account_public_key, account_key);
180199
}
181200

182201
#[test]
183202
fn test_from_bytes() {
184203
let tls_key = create_test_key();
185-
let report_data_v1 = ReportDataV1::new(tls_key);
204+
let account_key = create_test_key();
205+
let report_data_v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
186206
let bytes = report_data_v1.to_bytes();
187207

188208
let hash = ReportDataV1::from_bytes(&bytes);
@@ -195,7 +215,8 @@ mod tests {
195215
#[test]
196216
fn test_binary_version_placement() {
197217
let tls_key = create_test_key();
198-
let bytes = ReportDataV1::new(tls_key).to_bytes();
218+
let account_key = create_test_key();
219+
let bytes = ReportDataV1::new(tls_key, account_key).to_bytes();
199220

200221
let version_bytes =
201222
&bytes[BINARY_VERSION_OFFSET..BINARY_VERSION_OFFSET + BINARY_VERSION_SIZE];
@@ -205,20 +226,21 @@ mod tests {
205226
#[test]
206227
fn test_public_key_hash_placement() {
207228
let tls_key = create_test_key();
208-
let report_data_v1 = ReportDataV1::new(tls_key.clone());
229+
let account_key = create_test_key();
230+
let report_data_v1 = ReportDataV1::new(tls_key.clone(), account_key.clone());
209231
let bytes = report_data_v1.to_bytes();
210232

211-
let report_data = ReportData::V1(report_data_v1);
233+
let report_data = ReportData::V1(report_data_v1.clone());
212234
assert_eq!(report_data.to_bytes(), bytes);
213235

214236
let hash_bytes = &bytes[ReportDataV1::PUBLIC_KEYS_OFFSET
215237
..ReportDataV1::PUBLIC_KEYS_OFFSET + ReportDataV1::PUBLIC_KEYS_HASH_SIZE];
216238
assert_ne!(hash_bytes, &[0u8; ReportDataV1::PUBLIC_KEYS_HASH_SIZE]);
217239

240+
// Expected hash = sha3_384(tls || account)
218241
let mut hasher = Sha3_384::new();
219-
// Skip first byte as it is used for identifier for the curve type.
220-
let key_data = &tls_key.as_bytes()[1..];
221-
hasher.update(key_data);
242+
hasher.update(&tls_key.as_bytes()[1..]);
243+
hasher.update(&account_key.as_bytes()[1..]);
222244
let expected: [u8; ReportDataV1::PUBLIC_KEYS_HASH_SIZE] = hasher.finalize().into();
223245

224246
assert_eq!(hash_bytes, &expected);
@@ -227,7 +249,8 @@ mod tests {
227249
#[test]
228250
fn test_zero_padding() {
229251
let tls_key = create_test_key();
230-
let bytes = ReportDataV1::new(tls_key).to_bytes();
252+
let account_key = create_test_key();
253+
let bytes = ReportDataV1::new(tls_key, account_key).to_bytes();
231254

232255
let padding =
233256
&bytes[ReportDataV1::PUBLIC_KEYS_OFFSET + ReportDataV1::PUBLIC_KEYS_HASH_SIZE..];
@@ -237,7 +260,8 @@ mod tests {
237260
#[test]
238261
fn test_report_data_size() {
239262
let tls_key = create_test_key();
240-
let bytes = ReportDataV1::new(tls_key);
263+
let account_key = create_test_key();
264+
let bytes = ReportDataV1::new(tls_key, account_key);
241265
assert_eq!(bytes.to_bytes().len(), REPORT_DATA_SIZE);
242266
}
243267
}

crates/attestation/tests/test_attestation_verification.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use mpc_primitives::hash::{LauncherDockerComposeHash, MpcDockerImageHash};
66
use near_sdk::PublicKey;
77
use rstest::rstest;
88
use test_utils::attestation::{
9-
image_digest, launcher_compose_digest, mock_dstack_attestation, p2p_tls_key,
9+
image_digest, launcher_compose_digest, mock_dstack_attestation, p2p_tls_key, account_key,
1010
};
1111

1212
#[rstest]
@@ -20,7 +20,10 @@ fn test_mock_attestation_verify(
2020
let tls_key = "ed25519:DcA2MzgpJbrUATQLLceocVckhhAqrkingax4oJ9kZ847"
2121
.parse()
2222
.unwrap();
23-
let report_data = ReportData::V1(ReportDataV1::new(tls_key));
23+
let account_key = "ed25519:5v8Y8ZLoxZzCVtYpjh1cYdFrRh1p9EXAMPLEaQJ5sP4o"
24+
.parse()
25+
.unwrap();
26+
let report_data = ReportData::V1(ReportDataV1::new(tls_key,account_key));
2427
let attestation = Attestation::Mock(local_attestation);
2528

2629
assert_eq!(
@@ -33,8 +36,9 @@ fn test_mock_attestation_verify(
3336
fn test_verify_method_signature() {
3437
let attestation = mock_dstack_attestation();
3538
let tls_key: PublicKey = p2p_tls_key();
39+
let account_key: PublicKey = account_key();
3640

37-
let report_data = ReportData::V1(ReportDataV1::new(tls_key));
41+
let report_data = ReportData::V1(ReportDataV1::new(tls_key,account_key));
3842
let timestamp_s = 1755186041_u64;
3943

4044
let allowed_mpc_image_digest: MpcDockerImageHash = image_digest();

crates/contract/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,11 @@ impl MpcContract {
570570
let tee_upgrade_deadline_duration =
571571
Duration::from_secs(self.config.tee_upgrade_deadline_duration_seconds);
572572

573-
// Verify the TEE quote and Docker image for the proposed participant
573+
// Verify the TEE quote (including TLS and account keys) and Docker image for the proposed participant
574574
let status = self.tee_state.verify_proposed_participant_attestation(
575575
&proposed_participant_attestation,
576576
tls_public_key.clone(),
577+
account_key.clone(),
577578
tee_upgrade_deadline_duration,
578579
);
579580

@@ -587,6 +588,7 @@ impl MpcContract {
587588
NodeId {
588589
account_id: account_id.clone(),
589590
tls_public_key,
591+
account_public_key: Some(account_key),
590592
},
591593
proposed_participant_attestation,
592594
);
@@ -1433,6 +1435,7 @@ impl MpcContract {
14331435
// ensure that this node has a valid TEE quote
14341436
let node_id = NodeId {
14351437
account_id: account_id.clone(),
1438+
account_public_key: Some(expected_destination_node.signer_account_pk.clone()),
14361439
tls_public_key: expected_destination_node
14371440
.destination_node_info
14381441
.sign_pk
@@ -2377,6 +2380,7 @@ mod tests {
23772380
NodeId {
23782381
account_id: self.signer_account_id.clone(),
23792382
tls_public_key: self.attestation_tls_key.clone(),
2383+
account_public_key: Some(self.signer_account_pk.clone()),
23802384
},
23812385
valid_participant_attestation,
23822386
);

crates/contract/src/primitives/participants.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl Participants {
210210
.map(|(account_id, _, p_info)| NodeId {
211211
account_id: account_id.clone(),
212212
tls_public_key: p_info.sign_pk.clone(),
213+
account_public_key: None,
213214
})
214215
.collect()
215216
}

crates/contract/src/state/key_event.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl KeyEvent {
105105
.as_mut()
106106
.unwrap()
107107
.vote_success(candidate, public_key)?
108-
{
108+
{
109109
VoteSuccessResult::Voted(count) => {
110110
if count == self.parameters.participants().len() {
111111
Ok(true)

crates/contract/src/tee/tee_state.rs

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
use crate::{
23
primitives::{key_state::AuthenticatedParticipantId, participants::Participants},
34
storage_keys::StorageKey,
@@ -16,14 +17,35 @@ use borsh::{BorshDeserialize, BorshSerialize};
1617
use mpc_primitives::hash::LauncherDockerComposeHash;
1718
use near_sdk::{env, near, store::IterableMap, AccountId, PublicKey};
1819
use std::{collections::HashSet, time::Duration};
20+
use std::hash::{Hash, Hasher};
1921

2022
#[near(serializers=[borsh, json])]
21-
#[derive(Debug, Eq, Ord, PartialEq, PartialOrd, Clone, Hash)]
23+
#[derive(Debug, Ord, PartialOrd, Clone)]
2224
pub struct NodeId {
23-
/// Operator account
25+
/// Operator account (on-chain identity)
2426
pub account_id: AccountId,
2527
/// TLS public key
2628
pub tls_public_key: PublicKey,
29+
/// Account signing public key (optional for backward compatibility)
30+
pub account_public_key: Option<PublicKey>,
31+
}
32+
33+
// Implement Eq + Hash ignoring account_public_key
34+
impl PartialEq for NodeId {
35+
fn eq(&self, other: &Self) -> bool {
36+
self.account_id == other.account_id
37+
&& self.tls_public_key == other.tls_public_key
38+
}
39+
}
40+
41+
impl Eq for NodeId {}
42+
43+
impl Hash for NodeId {
44+
fn hash<H: Hasher>(&self, state: &mut H) {
45+
self.account_id.hash(state);
46+
self.tls_public_key.hash(state);
47+
// intentionally ignoring account_public_key
48+
}
2749
}
2850

2951
pub enum TeeValidationResult {
@@ -64,9 +86,13 @@ impl TeeState {
6486
&mut self,
6587
attestation: &Attestation,
6688
tls_public_key: PublicKey,
89+
account_public_key: PublicKey,
6790
tee_upgrade_deadline_duration: Duration,
6891
) -> TeeQuoteStatus {
69-
let expected_report_data = ReportData::V1(ReportDataV1::new(tls_public_key));
92+
// Recreate the exact same ReportData that the enclave produced
93+
let expected_report_data =
94+
ReportData::V1(ReportDataV1::new(tls_public_key, account_public_key));
95+
7096
let is_valid = attestation.verify(
7197
expected_report_data,
7298
Self::current_time_seconds(),
@@ -96,8 +122,11 @@ impl TeeState {
96122
return TeeQuoteStatus::None;
97123
};
98124

99-
let expected_report_data =
100-
ReportData::V1(ReportDataV1::new(node_id.tls_public_key.clone()));
125+
let expected_report_data = ReportData ::new(
126+
node_id.tls_public_key.clone(),
127+
node_id.account_public_key.clone(),
128+
);
129+
101130
let time_stamp_seconds = Self::current_time_seconds();
102131

103132
let quote_result = participant_attestation.verify(
@@ -132,12 +161,14 @@ impl TeeState {
132161
.iter()
133162
.filter(|(account_id, _, participant_info)| {
134163
let tls_public_key = participant_info.sign_pk.clone();
164+
135165

136166
matches!(
137167
self.verify_tee_participant(
138168
&NodeId {
139169
account_id: account_id.clone(),
140-
tls_public_key
170+
tls_public_key : tls_public_key,
171+
account_public_key: None,
141172
},
142173
tee_upgrade_deadline_duration
143174
),
@@ -222,6 +253,7 @@ impl TeeState {
222253
.map(|(account_id, _, p_info)| NodeId {
223254
account_id: account_id.clone(),
224255
tls_public_key: p_info.sign_pk.clone(),
256+
account_public_key: None,
225257
})
226258
.collect();
227259

@@ -244,6 +276,14 @@ impl TeeState {
244276
pub fn get_tee_accounts(&self) -> Vec<NodeId> {
245277
self.participants_attestations.keys().cloned().collect()
246278
}
279+
280+
/// Find a NodeId by its TLS public key.
281+
pub fn find_node_id_by_tls_key(&self, tls_public_key: &PublicKey) -> Option<NodeId> {
282+
self.participants_attestations
283+
.keys()
284+
.find(|node_id| &node_id.tls_public_key == tls_public_key)
285+
.cloned()
286+
}
247287
}
248288

249289
#[cfg(test)]
@@ -269,6 +309,7 @@ mod tests {
269309
.map(|(account_id, _, p_info)| NodeId {
270310
account_id: account_id.clone(),
271311
tls_public_key: p_info.sign_pk.clone(),
312+
account_public_key: Some(bogus_ed25519_near_public_key()),//TODO check if this is ok.
272313
})
273314
.collect();
274315

@@ -277,6 +318,7 @@ mod tests {
277318

278319
let non_participant_uid = NodeId {
279320
account_id: non_participant.clone(),
321+
account_public_key: Some(bogus_ed25519_near_public_key()),
280322
tls_public_key: bogus_ed25519_near_public_key(),
281323
};
282324
for node_id in &participant_nodes {

0 commit comments

Comments
 (0)