Skip to content
Merged
279 changes: 259 additions & 20 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use std::convert::TryFrom;

use anyhow::{anyhow, bail, Error, Result};
use base64::Engine;
use serde::{Deserialize, Serialize};
use std::io::IsTerminal;
use tpm2_policy::TPMPolicyStep;
Expand Down Expand Up @@ -35,8 +36,9 @@ impl TryFrom<&TPM2Config> for TPMPolicyStep {
match (&cfg.pcr_ids, &cfg.policy_pubkey_path) {
(Some(_), Some(pubkey_path)) => Ok(TPMPolicyStep::Or([
Box::new(TPMPolicyStep::PCRs(
cfg.get_pcr_hash_alg(),
cfg.get_pcr_ids().unwrap(),
cfg.get_pcr_hash_alg()?,
cfg.get_pcr_ids()?
.ok_or_else(|| anyhow!("pcr_ids unexpectedly empty"))?,
Box::new(TPMPolicyStep::NoStep),
)),
Box::new(get_authorized_policy_step(
Expand All @@ -52,8 +54,9 @@ impl TryFrom<&TPM2Config> for TPMPolicyStep {
Box::new(TPMPolicyStep::NoStep),
])),
(Some(_), None) => Ok(TPMPolicyStep::PCRs(
cfg.get_pcr_hash_alg(),
cfg.get_pcr_ids().unwrap(),
cfg.get_pcr_hash_alg()?,
cfg.get_pcr_ids()?
.ok_or_else(|| anyhow!("pcr_ids unexpectedly empty"))?,
Box::new(TPMPolicyStep::NoStep),
)),
(None, Some(pubkey_path)) => {
Expand All @@ -71,36 +74,48 @@ pub(crate) const DEFAULT_POLICY_REF: &str = "";
impl TPM2Config {
pub(super) fn get_pcr_hash_alg(
&self,
) -> tss_esapi::interface_types::algorithm::HashingAlgorithm {
) -> anyhow::Result<tss_esapi::interface_types::algorithm::HashingAlgorithm> {
crate::utils::get_hash_alg_from_name(self.pcr_bank.as_ref())
}

pub(super) fn get_name_hash_alg(
&self,
) -> tss_esapi::interface_types::algorithm::HashingAlgorithm {
) -> anyhow::Result<tss_esapi::interface_types::algorithm::HashingAlgorithm> {
crate::utils::get_hash_alg_from_name(self.hash.as_ref())
}

pub(super) fn get_pcr_ids(&self) -> Option<Vec<u64>> {
pub(super) fn get_pcr_ids(&self) -> Result<Option<Vec<u64>>> {
match &self.pcr_ids {
None => None,
None => Ok(None),
Some(serde_json::Value::Array(vals)) => {
Some(vals.iter().map(|x| x.as_u64().unwrap()).collect())
let ids: Result<Vec<u64>> = vals
.iter()
.map(|x| {
x.as_u64()
.ok_or_else(|| anyhow!("non-u64 value in pcr_ids"))
})
.collect();
Ok(Some(ids?))
}
_ => panic!("Unexpected type found for pcr_ids"),
_ => bail!("Unexpected type found for pcr_ids"),
}
}

pub(super) fn get_pcr_ids_str(&self) -> Option<String> {
pub(super) fn get_pcr_ids_str(&self) -> Result<Option<String>> {
match &self.pcr_ids {
None => None,
Some(serde_json::Value::Array(vals)) => Some(
vals.iter()
.map(|x| x.as_u64().unwrap().to_string())
.collect::<Vec<String>>()
.join(","),
),
_ => panic!("Unexpected type found for pcr_ids"),
None => Ok(None),
Some(serde_json::Value::Array(vals)) => {
let strs: Result<Vec<String>> = vals
.iter()
.map(|x| {
x.as_u64()
.map(|v| v.to_string())
.ok_or_else(|| anyhow!("non-u64 value in pcr_ids"))
})
.collect();
Ok(Some(strs?.join(",")))
}
_ => bail!("Unexpected type found for pcr_ids"),
}
}

Expand All @@ -109,6 +124,53 @@ impl TPM2Config {
if self.pcr_ids.is_some() && self.pcr_bank.is_none() {
self.pcr_bank = Some("sha256".to_string());
}
if let Some(ref hash) = self.hash {
crate::utils::get_hash_alg_from_name(Some(hash))?;
}
if let Some(ref bank) = self.pcr_bank {
crate::utils::get_hash_alg_from_name(Some(bank))?;
}
// tpm2-policy 0.6.0 hardcodes SHA-256 for policy sessions on the
// decrypt path, so non-SHA-256 name hash with PCR binding would
// produce tokens that encrypt successfully but can never be unsealed.
if self.pcr_ids.is_some() {
if let Some(ref hash) = self.hash {
if hash.to_lowercase() != "sha256" {
bail!(
"non-SHA-256 hash is not supported with PCR binding \
(tpm2-policy hardcodes SHA-256 for policy sessions)"
);
}
}
}
if self.pcr_digest.is_some() && self.pcr_ids.is_none() {
bail!("pcr_digest requires pcr_ids");
}
if let Some(ref digest) = self.pcr_digest {
let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(digest)
.map_err(|e| anyhow!("invalid pcr_digest base64: {}", e))?;
if decoded.is_empty() {
bail!("pcr_digest must not be empty");
}
if let Some(ref pcr_ids) = self.pcr_ids {
let num_pcrs = match pcr_ids {
serde_json::Value::Array(v) => v.len(),
_ => bail!("pcr_ids has unexpected type (expected array)"),
};
let hash_size = crate::utils::hash_digest_size(self.pcr_bank.as_ref())?;
let expected = num_pcrs * hash_size;
if decoded.len() != expected {
bail!(
"pcr_digest length {} does not match expected {} ({} PCRs * {} bytes)",
decoded.len(),
expected,
num_pcrs,
hash_size
);
}
}
}
// Make use of the defaults if not specified
if self.use_policy.is_some() && self.use_policy.unwrap() {
if self.policy_path.is_none() {
Expand All @@ -126,6 +188,9 @@ impl TPM2Config {
{
eprintln!("To use a policy, please specifiy use_policy: true. Not specifying this will be a fatal error in a next release");
}
if self.pcr_digest.is_some() && self.policy_pubkey_path.is_some() {
bail!("pcr_digest cannot be combined with authorized policy");
}
if (self.policy_pubkey_path.is_some()
|| self.policy_path.is_some()
|| self.policy_ref.is_some())
Expand Down Expand Up @@ -168,6 +233,10 @@ impl TPM2Config {
if !new.is_u64() {
bail!("Non-positive string int");
}
let v = new.as_u64().unwrap();
if v > 23 {
bail!("PCR ID {} out of valid range (0-23)", v);
}
Ok(new)
}
Err(_) => Err(anyhow!("Unparseable string int")),
Expand All @@ -178,6 +247,10 @@ impl TPM2Config {
if !new.is_u64() {
return Err(anyhow!("Non-positive int"));
}
let v = new.as_u64().unwrap();
if v > 23 {
bail!("PCR ID {} out of valid range (0-23)", v);
}
Ok(new)
}
_ => Err(anyhow!("Invalid value in pcr_ids")),
Expand All @@ -186,9 +259,12 @@ impl TPM2Config {
self.pcr_ids = Some(serde_json::Value::Array(newvals?));
}

if let Some(serde_json::Value::Array(ref mut vals)) = self.pcr_ids {
vals.sort_by_key(|v| v.as_u64().unwrap_or(0));
}

match &self.pcr_ids {
None => Ok(()),
// The normalization above would've caught any non-ints
Some(serde_json::Value::Array(_)) => Ok(()),
_ => Err(anyhow!("Invalid type")),
}
Expand Down Expand Up @@ -271,4 +347,167 @@ mod tests {
let result = serde_json::from_str::<TPM2Config>(config_str);
assert!(result.is_ok());
}

#[test]
fn test_pcr_digest_with_policy_rejected() {
let config_str = r#"{"pcr_ids": [23], "pcr_digest": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", "use_policy": true}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("pcr_digest cannot be combined"));
}

#[test]
fn test_pcr_digest_empty_rejected() {
let config_str = r#"{"pcr_ids": [23], "pcr_digest": ""}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("must not be empty"));
}

#[test]
fn test_pcr_digest_invalid_base64_rejected() {
let config_str = r#"{"pcr_ids": [23], "pcr_digest": "not!valid!base64"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("invalid pcr_digest"));
}

#[test]
fn test_pcr_digest_wrong_length_rejected() {
// 27 A's = 20 bytes (SHA-1 size), but pcr_bank defaults to sha256 (32 bytes)
let config_str = r#"{"pcr_ids": [23], "pcr_digest": "AAAAAAAAAAAAAAAAAAAAAAAAAAA"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("does not match expected"));
}

#[test]
fn test_pcr_digest_correct_length_accepted() {
// 43 A's = 32 bytes, matching 1 PCR with default sha256 bank
let config_str =
r#"{"pcr_ids": [23], "pcr_digest": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_ok());
}

#[test]
fn test_pcr_digest_unsupported_bank_rejected() {
let config_str =
r#"{"pcr_ids": [23], "pcr_bank": "md5", "pcr_digest": "AAAAAAAAAAAAAAAA"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unsupported"));
}

#[test]
fn test_unsupported_hash_rejected() {
let config_str = r#"{"hash": "md5"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unsupported"));
}

#[test]
fn test_pcr_id_out_of_range_rejected() {
let config_str = r#"{"pcr_ids": [24]}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("out of valid range"));
}

#[test]
fn test_pcr_id_large_value_rejected() {
let config_str = r#"{"pcr_ids": [4294967296]}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("out of valid range"));
}

#[test]
fn test_pcr_digest_without_pcr_ids_rejected() {
let config_str = r#"{"pcr_digest": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("pcr_digest requires pcr_ids"));
}

#[test]
fn test_non_sha256_hash_with_pcr_ids_rejected() {
let config_str = r#"{"hash": "sha384", "pcr_ids": [7]}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("non-SHA-256 hash"));
}

#[test]
fn test_non_sha256_hash_with_pcr_digest_rejected() {
let config_str =
r#"{"hash": "sha384", "pcr_ids": [7], "pcr_digest": "AAAAAAAAAAAAAAAAAAAAAAAAAAAA"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("non-SHA-256 hash"));
}

#[test]
fn test_non_sha256_hash_without_pcr_ids_accepted() {
let config_str = r#"{"hash": "sha384"}"#;
let result = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize();
assert!(result.is_ok());
}

#[test]
fn test_pcr_ids_sorted_after_normalize() {
let config_str = r#"{"pcr_ids": [23, 7, 0]}"#;
let cfg = serde_json::from_str::<TPM2Config>(config_str)
.unwrap()
.normalize()
.unwrap();
let ids = cfg.get_pcr_ids().unwrap().unwrap();
assert_eq!(ids, vec![0, 7, 23]);
}
}
Loading