diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 085fda948629c7..b09c4131b21587 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -57,6 +57,32 @@ using Transport::SecureSession; namespace { Global gGroupPeerTable; + +/// RAII class for iterators that guarantees that Release() will be called +/// on the underlying type +template +class AutoRelease +{ +public: + AutoRelease(Releasable * iter) : mIter(iter) {} + ~AutoRelease() { Release(); } + + Releasable * operator->() { return mIter; } + const Releasable * operator->() const { return mIter; } + + bool IsNull() const { return mIter == nullptr; } + + void Release() + { + VerifyOrReturn(mIter != nullptr); + mIter->Release(); + mIter = nullptr; + } + +private: + Releasable * mIter = nullptr; +}; + } // namespace uint32_t EncryptedPacketBufferHandle::GetMessageCounter() const @@ -868,8 +894,11 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack // Trial decryption with GroupDataProvider Credentials::GroupDataProvider::GroupSession groupContext; - auto iter = groups->IterateGroupSessions(partialPacketHeader.GetSessionId()); - if (iter == nullptr) + + AutoRelease iter( + groups->IterateGroupSessions(partialPacketHeader.GetSessionId())); + + if (iter.IsNull()) { ChipLogError(Inet, "Failed to retrieve Groups iterator. Discarding everything"); return; @@ -916,7 +945,7 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack } #endif // CHIP_CONFIG_PRIVACY_ACCEPT_NONSPEC_SVE2 } - iter->Release(); + iter.Release(); if (!decrypted) { @@ -954,7 +983,6 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack gGroupPeerTable->FindOrAddPeer(groupContext.fabric_index, packetHeaderCopy.GetSourceNodeId().Value(), packetHeaderCopy.IsSecureSessionControlMsg(), counter)) { - if (Credentials::GroupDataProvider::SecurityPolicy::kTrustFirst == groupContext.security_policy) { err = counter->VerifyOrTrustFirstGroup(packetHeaderCopy.GetMessageCounter());