From eda9b3c297d2f1e11520e12e34d76c64d31e4fe5 Mon Sep 17 00:00:00 2001 From: Jonathan Hoyland Date: Mon, 23 Dec 2024 14:50:55 +0000 Subject: [PATCH] Add support for public extensions in Reports. --- crates/daphne-server/tests/e2e/e2e.rs | 260 ++++++++++++++++++++++ crates/daphne/src/error/aborts.rs | 24 +- crates/daphne/src/hpke.rs | 8 + crates/daphne/src/messages/mod.rs | 91 +++++++- crates/daphne/src/protocol/client.rs | 4 + crates/daphne/src/protocol/mod.rs | 50 +++++ crates/daphne/src/protocol/report_init.rs | 47 +++- crates/daphne/src/roles/leader/mod.rs | 43 +++- crates/daphne/src/roles/mod.rs | 40 ++++ 9 files changed, 548 insertions(+), 19 deletions(-) diff --git a/crates/daphne-server/tests/e2e/e2e.rs b/crates/daphne-server/tests/e2e/e2e.rs index 309bc6d7f..1cfb9ce4f 100644 --- a/crates/daphne-server/tests/e2e/e2e.rs +++ b/crates/daphne-server/tests/e2e/e2e.rs @@ -301,6 +301,10 @@ async fn leader_upload(version: DapVersion) { report_metadata: ReportMetadata { id: ReportId([1; 16]), time: t.now, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_shares: [ @@ -533,6 +537,150 @@ async fn leader_upload_taskprov_wrong_version(version: DapVersion) { async_test_versions!(leader_upload_taskprov_wrong_version); +#[tokio::test] +async fn leader_upload_taskprov_public() { + let version = DapVersion::Latest; + let t = TestRunner::default_with_version(version).await; + let client = t.http_client(); + let hpke_config_list = t.get_hpke_configs(version, client).await.unwrap(); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 10, + query: DapBatchMode::TimeInterval, + leader_url: t.task_config.leader_url.clone(), + helper_url: t.task_config.helper_url.clone(), + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + daphne::roles::aggregator::TaskprovConfig { + hpke_collector_config: &t.taskprov_collector_hpke_receiver.config, + vdaf_verify_key_init: &t.taskprov_vdaf_verify_key_init, + }, + ) + .unwrap(); + + let mut report = task_config + .vdaf + .produce_report( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + version, + ) + .unwrap(); + report.report_metadata.public_extensions = Some(vec![Extension::Taskprov]); + t.leader_request_expect_ok( + client, + &format!("tasks/{}/reports", task_id.to_base64url()), + &http::Method::POST, + DapMediaType::Report, + Some( + &taskprov_advertisement + .serialize_to_header_value(version) + .unwrap(), + ), + report.get_encoded_with_param(&version).unwrap(), + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn leader_upload_taksprov_public_errors() { + let version = DapVersion::Latest; + let t = TestRunner::default_with_version(version).await; + let client = t.http_client(); + let hpke_config_list = t.get_hpke_configs(version, client).await.unwrap(); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 10, + query: DapBatchMode::TimeInterval, + leader_url: t.task_config.leader_url.clone(), + helper_url: t.task_config.helper_url.clone(), + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + daphne::roles::aggregator::TaskprovConfig { + hpke_collector_config: &t.taskprov_collector_hpke_receiver.config, + vdaf_verify_key_init: &t.taskprov_vdaf_verify_key_init, + }, + ) + .unwrap(); + + // Repeated public extension + let mut report = task_config + .vdaf + .produce_report( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + version, + ) + .unwrap(); + report.report_metadata.public_extensions = Some(vec![Extension::Taskprov, Extension::Taskprov]); + t.leader_request_expect_abort( + client, + None, + &format!("tasks/{}/reports", task_id.to_base64url()), + &http::Method::POST, + DapMediaType::Report, + Some( + &taskprov_advertisement + .serialize_to_header_value(version) + .unwrap(), + ), + report.get_encoded_with_param(&version).unwrap(), + 400, + "invalidMessage", + ) + .await + .unwrap(); + + // Unsupported public extension + let mut report = task_config + .vdaf + .produce_report( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + version, + ) + .unwrap(); + report.report_metadata.public_extensions = Some(vec![ + Extension::Taskprov, + Extension::NotImplemented { + typ: 3, + payload: b"ignore".to_vec(), + }, + ]); + t.leader_request_expect_abort( + client, + None, + &format!("tasks/{}/reports", task_id.to_base64url()), + &http::Method::POST, + DapMediaType::Report, + Some( + &taskprov_advertisement + .serialize_to_header_value(version) + .unwrap(), + ), + report.get_encoded_with_param(&version).unwrap(), + 400, + "unsupportedExtension", + ) + .await + .unwrap(); +} + async fn internal_leader_process(version: DapVersion) { let t = TestRunner::default_with_version(version).await; let path = t.upload_path(); @@ -1348,6 +1496,118 @@ async fn leader_selected() { .unwrap(); } +#[tokio::test] +async fn leader_collect_taskprov_repeated_abort() { + let version = DapVersion::Latest; + const DAP_TASKPROV_COLLECTOR_TOKEN: &str = "I-am-the-collector"; + let t = TestRunner::default_with_version(version).await; + let batch_interval = t.batch_interval(); + + let client = t.http_client(); + let hpke_config_list = t.get_hpke_configs(version, client).await.unwrap(); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 10, + query: DapBatchMode::TimeInterval, + leader_url: t.task_config.leader_url.clone(), + helper_url: t.task_config.helper_url.clone(), + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + daphne::roles::aggregator::TaskprovConfig { + hpke_collector_config: &t.taskprov_collector_hpke_receiver.config, + vdaf_verify_key_init: &t.taskprov_vdaf_verify_key_init, + }, + ) + .unwrap(); + + let path = TestRunner::upload_path_for_task(&task_id); + let method = match version { + DapVersion::Draft09 => &Method::PUT, + DapVersion::Latest => &Method::POST, + }; + + // The reports are uploaded in the background. + for _ in 0..t.task_config.min_batch_size { + let extensions = vec![Extension::Taskprov]; + t.leader_request_expect_ok( + client, + &path, + method, + DapMediaType::Report, + Some( + &taskprov_advertisement + .serialize_to_header_value(version) + .unwrap(), + ), + { + let mut report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + extensions, + version, + ) + .unwrap(); + report.report_metadata.public_extensions = Some(vec![Extension::Taskprov]); + report.get_encoded_with_param(&version).unwrap() + }, + ) + .await + .unwrap(); + } + + let agg_param = DapAggregationParam::Empty; + + // Get the collect URI. + let collect_req = CollectionReq { + query: Query::TimeInterval { batch_interval }, + agg_param: agg_param.get_encoded().unwrap(), + }; + let collect_uri = t + .leader_post_collect_using_token( + client, + DAP_TASKPROV_COLLECTOR_TOKEN, + Some(&taskprov_advertisement), + Some(&task_id), + collect_req.get_encoded_with_param(&t.version).unwrap(), + ) + .await + .unwrap(); + println!("collect_uri: {collect_uri}"); + + // Poll the collect URI before the CollectResp is ready. + let resp = t + .poll_collection_url_using_token(client, &collect_uri, DAP_TASKPROV_COLLECTOR_TOKEN) + .await + .unwrap(); + #[expect(clippy::format_in_format_args)] + { + assert_eq!( + resp.status(), + 400, + "response: {} {}", + format!("{resp:?}"), + resp.text().await.unwrap() + ); + } + + // The reports are aggregated in the background. + let agg_telem = t.internal_process(client).await.unwrap(); + assert_eq!( + agg_telem.reports_processed, task_config.min_batch_size, + "reports processed" + ); + assert_eq!(agg_telem.reports_aggregated, 0, "reports aggregated"); + assert_eq!(agg_telem.reports_collected, 0, "reports collected"); +} + async fn leader_collect_taskprov_ok(version: DapVersion) { const DAP_TASKPROV_COLLECTOR_TOKEN: &str = "I-am-the-collector"; let t = TestRunner::default_with_version(version).await; diff --git a/crates/daphne/src/error/aborts.rs b/crates/daphne/src/error/aborts.rs index 6ffede1c6..d77aac4aa 100644 --- a/crates/daphne/src/error/aborts.rs +++ b/crates/daphne/src/error/aborts.rs @@ -95,6 +95,10 @@ pub enum DapAbort { /// Unrecognized DAP task. Sent in response to a request indicating an unrecognized task ID. #[error("unrecognizedTask")] UnrecognizedTask { task_id: TaskId }, + + /// Unsupported Extension. Sent in response to a report upload with an unsupported extension. + #[error("unsupportedExtension")] + UnsupportedExtension { detail: String, task_id: TaskId }, } impl DapAbort { @@ -116,7 +120,8 @@ impl DapAbort { | Self::InvalidBatchSize { detail, task_id } | Self::BatchModeMismatch { detail, task_id } | Self::UnauthorizedRequest { detail, task_id } - | Self::InvalidMessage { detail, task_id } => ( + | Self::InvalidMessage { detail, task_id } + | Self::UnsupportedExtension { detail, task_id } => ( Some(task_id), Some(detail), None, @@ -259,6 +264,20 @@ impl DapAbort { }) } + pub fn unsupported_extension( + task_id: &TaskId, + unknown_extensions: &[u16], + ) -> Result { + let detail = serde_json::to_string(&unknown_extensions); + match detail { + Ok(s) => Ok(Self::UnsupportedExtension { + detail: s, + task_id: *task_id, + }), + Err(x) => Err(fatal_error!(err = %x,)), + } + } + fn title_and_type(&self) -> (&'static str, Option) { let (title, dap_abort_type) = match self { Self::BatchInvalid { .. } => ("Batch boundary check failed", Some(self.to_string())), @@ -300,6 +319,9 @@ impl DapAbort { Some(self.to_string()), ), Self::BadRequest(..) => ("Bad request", None), + Self::UnsupportedExtension { .. } => { + ("Unsupported extensions in report", Some(self.to_string())) + } }; ( diff --git a/crates/daphne/src/hpke.rs b/crates/daphne/src/hpke.rs index 4d770dde7..a8bd6edfc 100644 --- a/crates/daphne/src/hpke.rs +++ b/crates/daphne/src/hpke.rs @@ -612,6 +612,10 @@ mod test { report_metadata: &ReportMetadata { id: ReportId(rand::random()), time: rand::random(), + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, }; let plaintext = b"plaintext"; @@ -703,6 +707,10 @@ mod test { let report_metadata = &ReportMetadata { id: ReportId(rand::random()), time: rand::random(), + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }; let public_share = &vec![rand::random(); (0..100).choose(&mut rand::thread_rng()).unwrap()]; diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index a0e2abbf7..800dcba53 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -177,7 +177,7 @@ pub type Duration = u64; pub type Time = u64; /// Report extensions. -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize, Hash)] #[serde(rename_all = "snake_case")] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] pub enum Extension { @@ -240,28 +240,41 @@ impl ParameterizedDecode for Extension { pub struct ReportMetadata { pub id: ReportId, pub time: Time, + pub public_extensions: Option>, } impl ParameterizedEncode for ReportMetadata { fn encode_with_param( &self, - _version: &DapVersion, + version: &DapVersion, bytes: &mut Vec, ) -> Result<(), CodecError> { self.id.encode(bytes)?; self.time.encode(bytes)?; + match (version, &self.public_extensions) { + (DapVersion::Draft09, None) => (), + (DapVersion::Latest, Some(extensions)) => { + encode_u16_items(bytes, version, extensions.as_slice())?; + } + _ => return Err(CodecError::UnexpectedValue), + } + Ok(()) } } impl ParameterizedDecode for ReportMetadata { fn decode_with_param( - _version: &DapVersion, + version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { let metadata = Self { id: ReportId::decode(bytes)?, time: Time::decode(bytes)?, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(decode_u16_items(version, bytes)?), + }, }; Ok(metadata) @@ -1569,6 +1582,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([23; 16]), time: 1_637_364_244, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_shares: [ @@ -1596,6 +1613,41 @@ mod test { test_versions! {read_report} + fn report_metadata_encode_decode(version: DapVersion) { + let ext_rm = ReportMetadata { + id: ReportId([15; 16]), + time: 123_456, + public_extensions: Some(vec![Extension::NotImplemented { + typ: 0x10, + payload: vec![0x11, 0x12], + }]), + }; + let no_ext_rm = ReportMetadata { + id: ReportId([13; 16]), + time: 123_456, + public_extensions: None, + }; + let good_rm = match version { + DapVersion::Draft09 => &no_ext_rm, + DapVersion::Latest => &ext_rm, + }; + let bad_rm = match version { + DapVersion::Draft09 => &ext_rm, + DapVersion::Latest => &no_ext_rm, + }; + assert!(matches!( + bad_rm.get_encoded_with_param(&version).unwrap_err(), + CodecError::UnexpectedValue + )); + let bytes = good_rm.get_encoded_with_param(&version).unwrap(); + assert_eq!( + ReportMetadata::get_decoded_with_param(&version, bytes.as_slice()).unwrap(), + *good_rm + ); + } + + test_versions! {report_metadata_encode_decode} + fn partial_batch_selector_encode_decode(version: DapVersion) { const TEST_DATA_DRAFT09: &[u8] = &[1]; const TEST_DATA_LATEST: &[u8] = &[1, 0, 0]; @@ -1683,14 +1735,15 @@ mod test { 0, 0, 0, 32, 116, 104, 105, 115, 32, 105, 115, 32, 97, 110, 32, 97, 103, 103, 114, 101, 103, 97, 116, 105, 111, 110, 32, 112, 97, 114, 97, 109, 101, 116, 101, 114, 2, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 158, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, - 0, 0, 0, 0, 97, 152, 38, 185, 0, 0, 0, 12, 112, 117, 98, 108, 105, 99, 32, 115, 104, - 97, 114, 101, 23, 0, 16, 101, 110, 99, 97, 112, 115, 117, 108, 97, 116, 101, 100, 32, - 107, 101, 121, 0, 0, 0, 10, 99, 105, 112, 104, 101, 114, 116, 101, 120, 116, 0, 0, 0, - 10, 112, 114, 101, 112, 32, 115, 104, 97, 114, 101, 17, 17, 17, 17, 17, 17, 17, 17, 17, - 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 9, 194, 107, 103, 0, 0, 0, 12, 112, 117, 98, - 108, 105, 99, 32, 115, 104, 97, 114, 101, 0, 0, 0, 0, 0, 0, 10, 99, 105, 112, 104, 101, - 114, 116, 101, 120, 116, 0, 0, 0, 10, 112, 114, 101, 112, 32, 115, 104, 97, 114, 101, + 0, 0, 0, 0, 0, 0, 162, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, + 0, 0, 0, 0, 97, 152, 38, 185, 0, 0, 0, 0, 0, 12, 112, 117, 98, 108, 105, 99, 32, 115, + 104, 97, 114, 101, 23, 0, 16, 101, 110, 99, 97, 112, 115, 117, 108, 97, 116, 101, 100, + 32, 107, 101, 121, 0, 0, 0, 10, 99, 105, 112, 104, 101, 114, 116, 101, 120, 116, 0, 0, + 0, 10, 112, 114, 101, 112, 32, 115, 104, 97, 114, 101, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 9, 194, 107, 103, 0, 0, 0, 0, 0, 12, 112, + 117, 98, 108, 105, 99, 32, 115, 104, 97, 114, 101, 0, 0, 0, 0, 0, 0, 10, 99, 105, 112, + 104, 101, 114, 116, 101, 120, 116, 0, 0, 0, 10, 112, 114, 101, 112, 32, 115, 104, 97, + 114, 101, ]; let want = AggregationJobInitReq { @@ -1704,6 +1757,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([99; 16]), time: 1_637_361_337, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { @@ -1719,6 +1776,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([17; 16]), time: 163_736_423, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { @@ -1758,6 +1819,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([99; 16]), time: 1_637_361_337, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { @@ -1773,6 +1838,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([17; 16]), time: 163_736_423, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { diff --git a/crates/daphne/src/protocol/client.rs b/crates/daphne/src/protocol/client.rs index 253bd8b97..cbfdc21f3 100644 --- a/crates/daphne/src/protocol/client.rs +++ b/crates/daphne/src/protocol/client.rs @@ -76,6 +76,10 @@ impl VdafConfig { let metadata = ReportMetadata { id: *report_id, time, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }; let encoded_input_shares = input_shares.into_iter().map(|input_share| { diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 3dbb9832e..dac8837f1 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -784,6 +784,56 @@ mod test { test_versions! { handle_unrecognized_report_extensions } + fn handle_unknown_public_extensions_in_report(version: DapVersion) { + let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + let mut report = t + .task_config + .vdaf + .produce_report( + &t.client_hpke_config_list, + t.now, + &t.task_id, + DapMeasurement::U32Vec(vec![1; 10]), + version, + ) + .unwrap(); + report.report_metadata.public_extensions = Some(vec![ + Extension::NotImplemented { + typ: 0x01, + payload: b"This is ignored".to_vec(), + }, + Extension::NotImplemented { + typ: 0x02, + payload: b"This is ignored too".to_vec(), + }, + ]); + let report_metadata = report.report_metadata.clone(); + let [leader_share, _] = report.encrypted_input_shares; + let initialized_report = InitializedReport::from_client( + &t.leader_hpke_receiver_config, + t.valid_report_time_range(), + &t.task_id, + &t.task_config, + ReportShare { + report_metadata: report.report_metadata, + public_share: report.public_share, + encrypted_input_share: leader_share, + }, + &DapAggregationParam::Empty, + ) + .unwrap(); + + assert_eq!(initialized_report.metadata(), &report_metadata); + assert_matches!( + initialized_report, + InitializedReport::Rejected { + report_err: ReportError::InvalidMessage, + .. + } + ); + } + test_versions! {handle_unknown_public_extensions_in_report} + fn handle_repeated_report_extensions(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let report = t diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs index c38b5fdd9..abc1c1f4b 100644 --- a/crates/daphne/src/protocol/report_init.rs +++ b/crates/daphne/src/protocol/report_init.rs @@ -158,6 +158,28 @@ impl

InitializedReport

{ _ => {} } + // We don't check for duplicates here, because we check for them later + // on, and the taskprov extension, the only one we support, has no + // side-effects if processed when the report should have been rejected. + + let mut taskprov_indicated = false; + match ( + &report_share.report_metadata.public_extensions, + task_config.version, + ) { + (Some(extensions), crate::DapVersion::Latest) => { + for extension in extensions { + match extension { + Extension::Taskprov { .. } => { + taskprov_indicated |= task_config.method_is_taskprov; + } + Extension::NotImplemented { .. } => reject!(InvalidMessage), + } + } + } + (None, crate::DapVersion::Draft09) => (), + (_, _) => reject!(InvalidMessage), + } // decrypt input share let PlaintextInputShare { extensions, @@ -194,18 +216,31 @@ impl

InitializedReport

{ // Handle report extensions. { - if no_duplicates(extensions.iter().map(|e| e.type_code())).is_err() { + // Check for duplicates in public and private extensions + if no_duplicates( + extensions + .iter() + .chain( + report_share + .report_metadata + .public_extensions + .as_deref() + .unwrap_or_default(), + ) + .map(|e| e.type_code()), + ) + .is_err() + { reject!(InvalidMessage) } - let mut taskprov_indicated = false; + for extension in extensions { match extension { - Extension::Taskprov { .. } if task_config.method_is_taskprov => { - taskprov_indicated = true; + Extension::Taskprov { .. } => { + taskprov_indicated |= task_config.method_is_taskprov; } - // Reject reports with unrecognized extensions. - _ => reject!(InvalidMessage), + Extension::NotImplemented { .. } => reject!(InvalidMessage), } } diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index aed51c207..ad0bcd042 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -20,7 +20,7 @@ use crate::{ messages::{ taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq, AggregationJobId, AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId, - CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId, + CollectionReq, Extension, Interval, PartialBatchSelector, Query, Report, TaskId, }, metrics::{DaphneRequestType, ReportStatus}, roles::resolve_task_config, @@ -223,6 +223,47 @@ pub async fn handle_upload_req( .into()); } + match ( + &report.report_metadata.public_extensions, + task_config.version, + ) { + (Some(extensions), DapVersion::Latest) => { + let mut unknown_extensions = Vec::::new(); + if crate::protocol::no_duplicates(extensions.iter()).is_err() { + return Err(DapError::Abort(DapAbort::InvalidMessage { + detail: "Repeated public extension".into(), + task_id, + })); + }; + for extension in extensions { + match extension { + Extension::Taskprov => (), + Extension::NotImplemented { typ, .. } => unknown_extensions.push(*typ), + } + } + + if !unknown_extensions.is_empty() { + return match DapAbort::unsupported_extension(&task_id, &unknown_extensions) { + Ok(abort) => Err::<(), DapError>(abort.into()), + Err(fatal) => Err(fatal), + }; + } + } + (None, DapVersion::Draft09) => (), + (Some(_), DapVersion::Draft09) => { + return Err(DapError::Abort(DapAbort::version_mismatch( + DapVersion::Draft09, + DapVersion::Latest, + ))) + } + (None, DapVersion::Latest) => { + return Err(DapError::Abort(DapAbort::version_mismatch( + DapVersion::Latest, + DapVersion::Draft09, + ))) + } + } + // Store the report for future processing. At this point, the report may be rejected if // the Leader detects that the report was replayed or pertains to a batch that has already // been collected. diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 0754a9a61..8a982b59e 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -765,6 +765,46 @@ mod test { async_test_versions! { handle_agg_job_req_failure_hpke_decrypt_error } + async fn handle_unknown_public_extensions(version: DapVersion) { + let t = Test::new(version); + let task_id = &t.time_interval_task_id; + let task_config = t.leader.unchecked_get_task_config(task_id).await; + let mut report = t.gen_test_report(task_id).await; + report.report_metadata.public_extensions = Some(vec![Extension::NotImplemented { + typ: 0x01, + payload: vec![0x01], + }]); + + let req = DapRequest { + meta: DapRequestMeta { + version: task_config.version, + media_type: Some(DapMediaType::Report), + task_id: *task_id, + ..Default::default() + }, + resource_id: (), + payload: report, + }; + match version { + DapVersion::Draft09 => assert_eq!( + leader::handle_upload_req(&*t.leader, req).await, + Err(DapError::Abort(DapAbort::version_mismatch( + DapVersion::Draft09, + DapVersion::Latest + ))) + ), + DapVersion::Latest => assert_eq!( + leader::handle_upload_req(&*t.leader, req).await, + Err(DapError::Abort(DapAbort::UnsupportedExtension { + detail: "[1]".into(), + task_id: *task_id + })) + ), + } + } + + async_test_versions! { handle_unknown_public_extensions } + async fn handle_agg_job_req_transition_continue(version: DapVersion) { let t = Test::new(version); let task_id = &t.time_interval_task_id;