Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CASESession] refactoring and improving testability of Sigma1 sending and handling, and Sigma2 Sending #36679

Merged
merged 19 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 109 additions & 89 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
#include <protocols/secure_channel/SessionResumptionStorage.h>
#include <protocols/secure_channel/StatusReport.h>
#include <system/SystemClock.h>
#include <system/TLVPacketBufferBackingStore.h>
#include <tracing/macros.h>
#include <tracing/metric_event.h>
#include <transport/SessionManager.h>
Expand All @@ -68,16 +67,13 @@ enum
kTag_TBSData_ReceiverPubKey = 4,
};

enum
{
kTag_Sigma1_InitiatorRandom = 1,
kTag_Sigma1_InitiatorSessionId = 2,
kTag_Sigma1_DestinationId = 3,
kTag_Sigma1_InitiatorEphPubKey = 4,
kTag_Sigma1_InitiatorMRPParams = 5,
kTag_Sigma1_ResumptionID = 6,
kTag_Sigma1_InitiatorResumeMIC = 7,
};
inline constexpr uint8_t kInitiatorRandomTag = 1;
inline constexpr uint8_t kInitiatorSessionIdTag = 2;
inline constexpr uint8_t kDestinationIdTag = 3;
inline constexpr uint8_t kInitiatorPubKeyTag = 4;
inline constexpr uint8_t kInitiatorMRPParamsTag = 5;
inline constexpr uint8_t kResumptionIDTag = 6;
inline constexpr uint8_t kResume1MICTag = 7;

enum
{
Expand Down Expand Up @@ -770,24 +766,19 @@ void CASESession::HandleConnectionClosed(Transport::ActiveTCPConnectionState * c
CHIP_ERROR CASESession::SendSigma1()
{
MATTER_TRACE_SCOPE("SendSigma1", "CASESession");
size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom
sizeof(uint16_t), // initiatorSessionId,
kSHA256_Hash_Length, // destinationId
kP256_PublicKey_Length, // InitiatorEphPubKey,
SessionParameters::kEstimatedTLVSize, // initiatorSessionParams
SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES);

System::PacketBufferTLVWriter tlvWriter;
System::PacketBufferHandle msg_R1;
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;
uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 };

Sigma1Param encodeSigma1Params;

// Lookup fabric info.
const auto * fabricInfo = mFabricsTable->FindFabricWithIndex(mFabricIndex);
VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INCORRECT_STATE);

// Validate that we have a session ID allocated.
VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE);
encodeSigma1Params.initiatorSessionId = GetLocalSessionId().Value();

// Generate an ephemeral keypair
mEphemeralKey = mFabricsTable->AllocateEphemeralKeypairForCASE();
Expand All @@ -797,16 +788,6 @@ CHIP_ERROR CASESession::SendSigma1()
// Fill in the random value
ReturnErrorOnFailure(DRBG_get_bytes(mInitiatorRandom, sizeof(mInitiatorRandom)));

// Construct Sigma1 Msg
msg_R1 = System::PacketBufferHandle::New(data_len);
VerifyOrReturnError(!msg_R1.IsNull(), CHIP_ERROR_NO_MEMORY);

tlvWriter.Init(std::move(msg_R1));
ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(mInitiatorRandom)));
// Retrieve Session Identifier
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value()));

// Generate a Destination Identifier based on the node we are attempting to reach
{
// Obtain originator IPK matching the fabric where we are trying to open a session. mIPK
Expand All @@ -821,14 +802,10 @@ CHIP_ERROR CASESession::SendSigma1()
MutableByteSpan destinationIdSpan(destinationIdentifier);
ReturnErrorOnFailure(GenerateCaseDestinationId(ByteSpan(mIPK), ByteSpan(mInitiatorRandom), rootPubKeySpan, fabricId,
mPeerNodeId, destinationIdSpan));
encodeSigma1Params.destinationId = destinationIdSpan;
}
ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(3), destinationIdentifier, sizeof(destinationIdentifier)));

ReturnErrorOnFailure(
tlvWriter.PutBytes(TLV::ContextTag(4), mEphemeralKey->Pubkey(), static_cast<uint32_t>(mEphemeralKey->Pubkey().Length())));

VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter));

// Try to find persistent session, and resume it.
bool resuming = false;
Expand All @@ -839,20 +816,20 @@ CHIP_ERROR CASESession::SendSigma1()
if (err == CHIP_NO_ERROR)
{
// Found valid resumption state, try to resume the session.
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(6), mResumeResumptionId));

uint8_t initiatorResume1MIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES];
MutableByteSpan resumeMICSpan(initiatorResume1MIC);
MutableByteSpan resumeMICSpan(encodeSigma1Params.initiatorResume1MIC);
ReturnErrorOnFailure(GenerateSigmaResumeMIC(ByteSpan(mInitiatorRandom), ByteSpan(mResumeResumptionId),
ByteSpan(kKDFS1RKeyInfo), ByteSpan(kResume1MIC_Nonce), resumeMICSpan));

ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(7), resumeMICSpan));
encodeSigma1Params.initiatorResumeMICSpan = resumeMICSpan;
encodeSigma1Params.sessionResumptionRequested = true;

resuming = true;
}
}

ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R1));
// Encode Sigma1 into into msg_R1
ReturnErrorOnFailure(EncodeSigma1(msg_R1, encodeSigma1Params));

ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() }));

Expand Down Expand Up @@ -884,6 +861,52 @@ CHIP_ERROR CASESession::SendSigma1()
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::EncodeSigma1(System::PacketBufferHandle & msg, Sigma1Param & inputParams)
{

MATTER_TRACE_SCOPE("EncodeSigma1", "CASESession");

size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom
sizeof(uint16_t), // initiatorSessionId,
kSHA256_Hash_Length, // destinationId
kP256_PublicKey_Length, // InitiatorEphPubKey,
SessionParameters::kEstimatedTLVSize, // initiatorSessionParams
SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES);

msg = System::PacketBufferHandle::New(data_len);
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY);

System::PacketBufferTLVWriter tlvWriter;
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;

tlvWriter.Init(std::move(msg));
ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
// TODO Pass this in the struct?
Alami-Amine marked this conversation as resolved.
Show resolved Hide resolved
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kInitiatorRandomTag), ByteSpan(mInitiatorRandom)));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kInitiatorSessionIdTag), inputParams.initiatorSessionId));

ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kDestinationIdTag), inputParams.destinationId));

// TODO Pass this in the struct?
Alami-Amine marked this conversation as resolved.
Show resolved Hide resolved
ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kInitiatorPubKeyTag), mEphemeralKey->Pubkey(),
static_cast<uint32_t>(mEphemeralKey->Pubkey().Length())));

// TODO is it redudunt?
Alami-Amine marked this conversation as resolved.
Show resolved Hide resolved
VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(kInitiatorMRPParamsTag), mLocalMRPConfig.Value(), tlvWriter));

if (inputParams.sessionResumptionRequested)
{
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kResumptionIDTag), mResumeResumptionId));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kResume1MICTag), inputParams.initiatorResumeMICSpan));
}

ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize(&msg));

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::HandleSigma1_and_SendSigma2(System::PacketBufferHandle && msg)
{
MATTER_TRACE_SCOPE("HandleSigma1_and_SendSigma2", "CASESession");
Expand Down Expand Up @@ -923,7 +946,7 @@ CHIP_ERROR CASESession::FindLocalNodeFromDestinationId(const ByteSpan & destinat
MutableByteSpan candidateDestinationIdSpan(candidateDestinationId);
ByteSpan candidateIpkSpan(ipkKeySet.epoch_keys[keyIdx].key);

err = GenerateCaseDestinationId(ByteSpan(candidateIpkSpan), ByteSpan(initiatorRandom), rootPubKeySpan, fabricId, nodeId,
err = GenerateCaseDestinationId(candidateIpkSpan, initiatorRandom, rootPubKeySpan, fabricId, nodeId,
candidateDestinationIdSpan);
if ((err == CHIP_NO_ERROR) && (candidateDestinationIdSpan.data_equal(destinationId)))
{
Expand Down Expand Up @@ -974,38 +997,43 @@ CHIP_ERROR CASESession::TryResumeSession(SessionResumptionStorage::ConstResumpti
CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
{
MATTER_TRACE_SCOPE("HandleSigma1", "CASESession");
CHIP_ERROR err = CHIP_NO_ERROR;
System::PacketBufferTLVReader tlvReader;

uint16_t initiatorSessionId;
ByteSpan destinationIdentifier;
ByteSpan initiatorRandom;

ChipLogProgress(SecureChannel, "Received Sigma1 msg");
MATTER_TRACE_COUNTER("Sigma1");

bool sessionResumptionRequested = false;
ByteSpan resumptionId;
ByteSpan resume1MIC;
ByteSpan initiatorPubKey;
CHIP_ERROR err = CHIP_NO_ERROR;
System::PacketBufferTLVReader tlvReader;

Sigma1Param parsedSigma1;

SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() }));

tlvReader.Init(std::move(msg));
tcarmelveilleux marked this conversation as resolved.
Show resolved Hide resolved
SuccessOrExit(err = ParseSigma1(tlvReader, initiatorRandom, initiatorSessionId, destinationIdentifier, initiatorPubKey,
sessionResumptionRequested, resumptionId, resume1MIC));

ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId);
SetPeerSessionId(initiatorSessionId);
SuccessOrExit(err = ParseSigma1(tlvReader, parsedSigma1));

ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", parsedSigma1.initiatorSessionId);
SetPeerSessionId(parsedSigma1.initiatorSessionId);

VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE);

if (sessionResumptionRequested && resumptionId.size() == SessionResumptionStorage::kResumptionIdSize &&
// TODO: Added by Amine, taken from inside ParseSigma1
// This was removed to remove the non-parsing parts from ParseSigma1, decoupling it from higher levels
// TODO: Should i change it?
// Set the recieved MRP parameters included with Sigma1
if (parsedSigma1.InitiatorMRPParamsPresent == true)
{
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
GetRemoteSessionParameters());
Alami-Amine marked this conversation as resolved.
Show resolved Hide resolved
}

if (parsedSigma1.sessionResumptionRequested &&
parsedSigma1.resumptionId.size() == SessionResumptionStorage::kResumptionIdSize &&
CHIP_NO_ERROR ==
TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(resumptionId.data()), resume1MIC, initiatorRandom))
TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(parsedSigma1.resumptionId.data()),
parsedSigma1.initiatorResumeMICSpan, parsedSigma1.initiatorRandom))
{
std::copy(initiatorRandom.begin(), initiatorRandom.end(), mInitiatorRandom);
std::copy(resumptionId.begin(), resumptionId.end(), mResumeResumptionId.begin());
std::copy(parsedSigma1.initiatorRandom.begin(), parsedSigma1.initiatorRandom.end(), mInitiatorRandom);
std::copy(parsedSigma1.resumptionId.begin(), parsedSigma1.resumptionId.end(), mResumeResumptionId.begin());

// Send Sigma2Resume message to the initiator
MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2Resume);
Expand All @@ -1023,7 +1051,7 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
}

// Attempt to match the initiator's desired destination based on local fabric table.
err = FindLocalNodeFromDestinationId(destinationIdentifier, initiatorRandom);
err = FindLocalNodeFromDestinationId(parsedSigma1.destinationId, parsedSigma1.initiatorRandom);
if (err == CHIP_NO_ERROR)
{
ChipLogProgress(SecureChannel, "CASE matched destination ID: fabricIndex %u, NodeID 0x" ChipLogFormatX64,
Expand All @@ -1035,13 +1063,13 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
else
{
ChipLogError(SecureChannel, "CASE failed to match destination ID with local fabrics");
ChipLogByteSpan(SecureChannel, destinationIdentifier);
ChipLogByteSpan(SecureChannel, parsedSigma1.destinationId);
}
SuccessOrExit(err);

// ParseSigma1 ensures that:
// mRemotePubKey.Length() == initiatorPubKey.size() == kP256_PublicKey_Length.
memcpy(mRemotePubKey.Bytes(), initiatorPubKey.data(), mRemotePubKey.Length());
memcpy(mRemotePubKey.Bytes(), parsedSigma1.initiatorEphPubKey.data(), mRemotePubKey.Length());

MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2);
err = SendSigma2();
Expand Down Expand Up @@ -2163,46 +2191,36 @@ CHIP_ERROR CASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralS
return err;
}

CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, ByteSpan & initiatorRandom,
uint16_t & initiatorSessionId, ByteSpan & destinationId, ByteSpan & initiatorEphPubKey,
bool & resumptionRequested, ByteSpan & resumptionId, ByteSpan & initiatorResumeMIC)
CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, Sigma1Param & output)
{
using namespace TLV;

constexpr uint8_t kInitiatorRandomTag = 1;
constexpr uint8_t kInitiatorSessionIdTag = 2;
constexpr uint8_t kDestinationIdTag = 3;
constexpr uint8_t kInitiatorPubKeyTag = 4;
constexpr uint8_t kInitiatorMRPParamsTag = 5;
constexpr uint8_t kResumptionIDTag = 6;
constexpr uint8_t kResume1MICTag = 7;

TLVType containerType = kTLVType_Structure;
ReturnErrorOnFailure(tlvReader.Next(containerType, AnonymousTag()));
ReturnErrorOnFailure(tlvReader.EnterContainer(containerType));

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorRandomTag)));
ReturnErrorOnFailure(tlvReader.GetByteView(initiatorRandom));
VerifyOrReturnError(initiatorRandom.size() == kSigmaParamRandomNumberSize, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.initiatorRandom));
VerifyOrReturnError(output.initiatorRandom.size() == kSigmaParamRandomNumberSize, CHIP_ERROR_INVALID_CASE_PARAMETER);

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorSessionIdTag)));
ReturnErrorOnFailure(tlvReader.Get(initiatorSessionId));
ReturnErrorOnFailure(tlvReader.Get(output.initiatorSessionId));

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kDestinationIdTag)));
ReturnErrorOnFailure(tlvReader.GetByteView(destinationId));
VerifyOrReturnError(destinationId.size() == kSHA256_Hash_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.destinationId));
VerifyOrReturnError(output.destinationId.size() == kSHA256_Hash_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorPubKeyTag)));
ReturnErrorOnFailure(tlvReader.GetByteView(initiatorEphPubKey));
VerifyOrReturnError(initiatorEphPubKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.initiatorEphPubKey));
VerifyOrReturnError(output.initiatorEphPubKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);

// Optional members start here.
CHIP_ERROR err = tlvReader.Next();
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag))
{
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader));
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
GetRemoteSessionParameters());
output.InitiatorMRPParamsPresent = true;

err = tlvReader.Next();
}

Expand All @@ -2212,16 +2230,18 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kResumptionIDTag))
{
resumptionIDTagFound = true;
ReturnErrorOnFailure(tlvReader.GetByteView(resumptionId));
VerifyOrReturnError(resumptionId.size() == SessionResumptionStorage::kResumptionIdSize, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.resumptionId));
VerifyOrReturnError(output.resumptionId.size() == SessionResumptionStorage::kResumptionIdSize,
CHIP_ERROR_INVALID_CASE_PARAMETER);
err = tlvReader.Next();
}

if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kResume1MICTag))
{
resume1MICTagFound = true;
ReturnErrorOnFailure(tlvReader.GetByteView(initiatorResumeMIC));
VerifyOrReturnError(initiatorResumeMIC.size() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.initiatorResumeMICSpan));
VerifyOrReturnError(output.initiatorResumeMICSpan.size() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES,
CHIP_ERROR_INVALID_CASE_PARAMETER);
err = tlvReader.Next();
}

Expand All @@ -2236,11 +2256,11 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,

if (resumptionIDTagFound && resume1MICTagFound)
{
resumptionRequested = true;
output.sessionResumptionRequested = true;
}
else if (!resumptionIDTagFound && !resume1MICTagFound)
{
resumptionRequested = false;
output.sessionResumptionRequested = false;
}
else
{
Expand Down
Loading
Loading