diff --git a/Sources/MRP/Applications/MSRP/MSRPApplication.swift b/Sources/MRP/Applications/MSRP/MSRPApplication.swift index ad1d419a..e0c1268b 100644 --- a/Sources/MRP/Applications/MSRP/MSRPApplication.swift +++ b/Sources/MRP/Applications/MSRP/MSRPApplication.swift @@ -138,8 +138,7 @@ struct MSRPPortState: Sendable { } public final class MSRPApplication: BaseApplication, BaseApplicationEventObserver, - BaseApplicationContextObserver, - ApplicationEventHandler, CustomStringConvertible, @unchecked Sendable where P == P + BaseApplicationContextObserver, CustomStringConvertible, @unchecked Sendable where P == P { private typealias TalkerRegistration = (Participant, any MSRPTalkerValue) @@ -398,50 +397,6 @@ public final class MSRPApplication: BaseApplication, BaseApplication .normalParticipant } - // If an MSRP message is received from a Port with an event value specifying - // the JoinIn or JoinMt message, and if the StreamID (35.2.2.8.2, - // 35.2.2.10.2), and Direction (35.2.1.2) all match those of an attribute - // already registered on that Port, and the Attribute Type (35.2.2.4) or - // FourPackedEvent (35.2.2.7.2) has changed, then the Bridge should behave as - // though an rLv! event (with immediate leavetimer expiration in the - // Registrar state table) was generated for the MAD in the Received MSRP - // Attribute Declarations before the rJoinIn! or rJoinMt! event for the - // attribute in the received message is processed - public func preApplicantEventHandler( - context: EventContext - ) async throws { - guard context.event == .rJoinIn || context.event == .rJoinMt else { return } - - let contextAttributeType = MSRPAttributeType(rawValue: context.attributeType)! - guard let contextDirection = contextAttributeType.direction else { return } - - let contextStreamID = (context.attributeValue as! MSRPStreamIDRepresentable).streamID - - try await context.participant.leaveNow { attributeType, _, attributeValue in - let attributeType = MSRPAttributeType(rawValue: attributeType)! - guard let direction = attributeType.direction else { return false } - let streamID = (attributeValue as! MSRPStreamIDRepresentable).streamID - - // force immediate leave if the streamID and direction match and the - // attribute type has changed - let isIncluded = contextStreamID == streamID && - contextDirection == direction && - contextAttributeType != attributeType - - // NB: we don't handle attribute subtypes here, they are handled in - // Participant.swift (this is something of a leaky abstraction) - if isIncluded { - _logger - .debug( - "MSRP: forcing immediate leave for stream \(streamID) owing to attribute change: \(attributeType)->\(contextAttributeType)" - ) - } - return isIncluded - } - } - - public func postApplicantEventHandler(context: EventContext) {} - // On receipt of a REGISTER_STREAM.request the MSRP Participant shall issue a // MAD_Join.request service primitive (10.2, 10.3). The attribute_type (10.2) // parameter of the request shall carry the appropriate Talker Attribute Type @@ -511,11 +466,10 @@ public final class MSRPApplication: BaseApplication, BaseApplication public func deregisterStream( streamID: MSRPStreamID ) async throws { - let talkerRegistration = try await _findTalkerRegistration(for: streamID) - let declarationType: MSRPDeclarationType - guard let talkerRegistration else { + guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else { throw MRPError.participantNotFound } + let declarationType: MSRPDeclarationType if talkerRegistration.1 is MSRPTalkerAdvertiseValue { declarationType = .talkerAdvertise } else { @@ -578,6 +532,38 @@ public final class MSRPApplication: BaseApplication, BaseApplication } extension MSRPApplication { + // Enforce mutual exclusion between talkerAdvertise and talkerFailed on a participant + private func _enforceTalkerMutualExclusion( + participant: Participant, + declarationType: MSRPDeclarationType, + streamID: MSRPStreamID, + eventSource: EventSource + ) async throws { + let oppositeType: MSRPAttributeType = declarationType == .talkerAdvertise ? .talkerFailed : + .talkerAdvertise + + let oppositeAttributes = await participant.findAttributes( + attributeType: oppositeType.rawValue, + matching: .matchAnyIndex(streamID.index) + ) + + for (_, attributeValue) in oppositeAttributes { + if eventSource == .map { + try? await participant.leave( + attributeType: oppositeType.rawValue, + attributeValue: attributeValue, + eventSource: eventSource + ) + } else { + try? await participant.deregister( + attributeType: oppositeType.rawValue, + attributeValue: attributeValue, + eventSource: eventSource + ) + } + } + } + private func _shouldPruneTalkerDeclaration( port: P, streamID: MSRPStreamID, @@ -892,6 +878,17 @@ extension MSRPApplication { "MSRP: register stream indication from port \(port) streamID \(talkerValue.streamID) declarationType \(declarationType) dataFrameParameters \(talkerValue.dataFrameParameters) isNew \(isNew) source \(eventSource)" ) + // Deregister the opposite talker type from the peer to ensure mutual exclusion + if eventSource == .peer { + let sourceParticipant = try findParticipant(for: contextIdentifier, port: port) + try await _enforceTalkerMutualExclusion( + participant: sourceParticipant, + declarationType: declarationType, + streamID: talkerValue.streamID, + eventSource: .peer + ) + } + // TL;DR: propagate Talker declarations to other ports try await apply(for: contextIdentifier) { participant in guard participant.port != port else { return } // don't propagate to source port @@ -930,12 +927,13 @@ extension MSRPApplication { accumulatedLatency += 500 // clause 35.2.2.8.6, 500ns default } - // if a Talker Failed attribute already exists when a Talker Advertised - // is registered, or vice versa, immediately deregister the existing - // attribute - try await participant.leaveStreamNow( - attributeType: declarationType.talkerComplement, - streamID: talkerValue.streamID + // Leave the opposite talker declaration type to ensure mutual exclusion + // (per spec, only one talker declaration type should exist per stream) + try await _enforceTalkerMutualExclusion( + participant: participant, + declarationType: declarationType, + streamID: talkerValue.streamID, + eventSource: .map ) if declarationType == .talkerAdvertise { @@ -970,11 +968,6 @@ extension MSRPApplication { eventSource: .map ) } catch let error as MSRPFailure { - // Leave any existing talkerAdvertise before joining talkerFailed - try await participant.leaveNow { attributeType, _, attributeValue in - attributeType == MSRPAttributeType.talkerAdvertise.rawValue && - (attributeValue as! MSRPStreamIDRepresentable).streamID == talkerValue.streamID - } let talkerFailed = MSRPTalkerFailedValue( streamID: talkerValue.streamID, dataFrameParameters: talkerValue.dataFrameParameters, @@ -1060,11 +1053,24 @@ extension MSRPApplication { ) else { return } + + // verify talker still exists (guard against race with talker departure) + guard let currentTalker = await _findTalkerRegistration( + for: talkerValue.streamID, + participant: talkerParticipant + ), currentTalker.streamID == talkerValue.streamID else { + _logger + .debug( + "MSRP: talker \(talkerValue.streamID) no longer exists, skipping port parameter update" + ) + return + } + try? await _updatePortParameters( port: participant.port, streamID: listenerRegistration.0.streamID, mergedDeclarationType: mergedDeclarationType, - talkerRegistration: (talkerParticipant, talkerValue) + talkerRegistration: (talkerParticipant, currentTalker) ) } } @@ -1140,13 +1146,14 @@ extension MSRPApplication { for streamID: MSRPStreamID, participant: Participant ) async -> (any MSRPTalkerValue)? { + // TalkerFailed takes precedence over TalkerAdvertise per spec if let value = await participant.findAttribute( - attributeType: MSRPAttributeType.talkerAdvertise.rawValue, + attributeType: MSRPAttributeType.talkerFailed.rawValue, matching: .matchAnyIndex(streamID.index) ) { value.1 as? (any MSRPTalkerValue) } else if let value = await participant.findAttribute( - attributeType: MSRPAttributeType.talkerFailed.rawValue, + attributeType: MSRPAttributeType.talkerAdvertise.rawValue, matching: .matchAnyIndex(streamID.index) ) { value.1 as? (any MSRPTalkerValue) @@ -1157,19 +1164,17 @@ extension MSRPApplication { private func _findTalkerRegistration( for streamID: MSRPStreamID - ) async throws -> TalkerRegistration? { + ) async -> TalkerRegistration? { var talkerRegistration: TalkerRegistration? await apply { participant in guard let participantTalker = await _findTalkerRegistration( for: streamID, participant: participant - ) else { + ), talkerRegistration == nil else { return } - if talkerRegistration == nil { - talkerRegistration = (participant, participantTalker) - } + talkerRegistration = (participant, participantTalker) } return talkerRegistration @@ -1182,17 +1187,9 @@ extension MSRPApplication { talkerRegistration: TalkerRegistration, isJoin: Bool ) async throws -> MSRPDeclarationType? { + var mergedDeclarationType = isJoin ? declarationType : nil let streamID = talkerRegistration.1.streamID - - var mergedDeclarationType: MSRPDeclarationType? = if isJoin { - if talkerRegistration.1 is MSRPTalkerFailedValue { - declarationType == nil ? nil : .listenerAskingFailed - } else { - declarationType - } - } else { - nil - } + var listenerCount = mergedDeclarationType != nil ? 1 : 0 // collect listener declarations from all other ports and merge declaration type await apply(for: contextIdentifier) { participant in @@ -1217,9 +1214,16 @@ extension MSRPApplication { with: mergedDeclarationType ) } + listenerCount += 1 } } + precondition(mergedDeclarationType == nil || listenerCount > 0) + + if talkerRegistration.1 is MSRPTalkerFailedValue, listenerCount > 0 { + mergedDeclarationType = .listenerAskingFailed + } + return mergedDeclarationType } @@ -1435,10 +1439,13 @@ extension MSRPApplication { isNew: Bool, eventSource: EventSource ) async throws { - guard let talkerRegistration = try? await _findTalkerRegistration(for: streamID) else { + guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else { + // no listener attribute propagation if no talker (35.2.4.4.1) + // this is an expected race condition - listener arrives before talker + // when talker arrives, _updateExistingListeners() will process it _logger - .error( - "MSRP: could not find talker registration for listener stream \(streamID)" + .debug( + "MSRP: listener registration for stream \(streamID) received before talker, will be processed when talker arrives" ) return } @@ -1616,7 +1623,7 @@ extension MSRPApplication { // StreamID of the Declaration matches a Stream that the Talker is // transmitting, then the Talker shall stop the transmission for this // Stream, if it is transmitting. - guard let talkerRegistration = try? await _findTalkerRegistration(for: streamID) else { + guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else { return } @@ -1783,25 +1790,6 @@ extension MSRPApplication { } } -private extension MSRPDeclarationType { - var talkerComplement: MSRPAttributeType! { - switch self { - case .talkerAdvertise: .talkerFailed - case .talkerFailed: .talkerAdvertise - default: nil - } - } -} - -private extension Participant { - func leaveStreamNow(attributeType: MSRPAttributeType, streamID: MSRPStreamID) async throws { - try await leaveNow { - $0 == attributeType.rawValue && - ($2 as! MSRPStreamIDRepresentable).streamID == streamID - } - } -} - #if canImport(FlyingFox) extension MSRPApplication: RestApiApplication { func registerRestApiHandlers(for httpServer: HTTPServer) async throws { diff --git a/Sources/MRP/Base/Event.swift b/Sources/MRP/Base/Event.swift index 9424c92a..c63ffdb0 100644 --- a/Sources/MRP/Base/Event.swift +++ b/Sources/MRP/Base/Event.swift @@ -43,6 +43,7 @@ public enum ProtocolEvent: Sendable { case leavetimer // leavetimer has expired (10.7.5.21) case leavealltimer // leavealltimer has expired (10.7.5.22) case periodictimer // periodictimer has expired (10.7.5.23) + case rLvNow // receive Leave message with immediate leavetimer expiration (35.2.6) fileprivate var _r: Bool { switch self { @@ -100,7 +101,7 @@ public enum EventSource: Sendable { case `internal` // event source was transitive via MAP function case map - // event source was a preApplicantEventHandler/postApplicantEventHandler hook + // event source was a event handler hook case application // event source was immediate re-registration after LeaveAll processing case leaveAll diff --git a/Sources/MRP/Model/Applicant.swift b/Sources/MRP/Model/Applicant.swift index e76b6e90..810d9899 100644 --- a/Sources/MRP/Model/Applicant.swift +++ b/Sources/MRP/Model/Applicant.swift @@ -199,6 +199,8 @@ private extension Applicant.State { } case .rLv: fallthrough + case .rLvNow: + fallthrough case .rLA: fallthrough case .ReDeclare: diff --git a/Sources/MRP/Model/Application.swift b/Sources/MRP/Model/Application.swift index 3ceff54c..a2e53502 100644 --- a/Sources/MRP/Model/Application.swift +++ b/Sources/MRP/Model/Application.swift @@ -214,10 +214,3 @@ extension Application { try await apply(for: contextIdentifier) { try await $0.redeclare() } } } - -public protocol ApplicationEventHandler: Application { - associatedtype A: Application - - func preApplicantEventHandler(context: EventContext) async throws - func postApplicantEventHandler(context: EventContext) -} diff --git a/Sources/MRP/Model/Participant.swift b/Sources/MRP/Model/Participant.swift index 84ed2d94..8248d56e 100644 --- a/Sources/MRP/Model/Participant.swift +++ b/Sources/MRP/Model/Participant.swift @@ -98,24 +98,6 @@ private enum EnqueuedEvent: Equatable, CustomStringConvertible { "EnqueuedEvent(\(unsafeAttributeEvent))" } } - - func canBeReplacedBy(_ newEvent: EnqueuedEvent) -> Bool { - if self == newEvent { - // if the event is identical, replace it (effectively, a no-op) - true - } else if let existingAttributeEvent = attributeEvent, - let newAttributeEvent = newEvent.attributeEvent - { - // clause 10.6 suggests the message actually transmitted is "that - // appropriate to the state of the machine when the opportunity is - // presented". I believe this means we can replace any event with the - // same attribute value. We could probably simplify this by just - // transmitting all current attribute values at TX opportunities. - existingAttributeEvent.attributeValue == newAttributeEvent.attributeValue - } else { - false - } - } } public final actor Participant: Equatable, Hashable, CustomStringConvertible { @@ -137,7 +119,6 @@ public final actor Participant: Equatable, Hashable, CustomStrin private var _leaveAll: LeaveAll! private var _jointimer: Timer? private var _transmissionOpportunityTimestamps: [ContinuousClock.Instant] = [] - private var _pendingTransmissionRequest = false private nonisolated let _controller: Weak> private nonisolated let _application: Weak @@ -184,12 +165,7 @@ public final actor Participant: Equatable, Hashable, CustomStrin // instance of this timer is required on a per-Port, per-MRP Participant // basis. The value of JoinTime used to initialize this timer is determined // in accordance with 10.7.11. - - // only required for shared media, in the point-to-point case packets - // are transmitted immediately - let jointimer = Timer(label: "jointimer", onExpiry: _onJoinTimerExpired) - jointimer.start(interval: JoinTime) - _jointimer = jointimer + _jointimer = Timer(label: "jointimer", onExpiry: _onJoinTimerExpired) // The Leave All Period Timer, leavealltimer, controls the frequency with // which the LeaveAll state machine generates LeaveAll PDUs. The timer is @@ -209,34 +185,21 @@ public final actor Participant: Equatable, Hashable, CustomStrin _leaveAll?.stopLeaveAllTimer() } - @Sendable - private func _onJoinTimerExpired() async throws { - guard let controller else { return } - if _type == .pointToPoint && _pendingTransmissionRequest { - // Check if rate limiting window has expired - _pendingTransmissionRequest = false - } - try await _requestTxOpportunity(eventSource: .joinTimer) - // Restart join timer for next period - _jointimer?.start(interval: controller.timerConfiguration.joinTime) - } - @Sendable private func _onLeaveAllTimerExpired() async throws { try await _handleLeaveAll(protocolEvent: .leavealltimer, eventSource: .leaveAllTimer) // Table 10.5: Request opportunity to transmit on entry to the Active state - try await _requestTxOpportunity(eventSource: .leaveAll) + _requestTxOpportunity(eventSource: .leaveAll) } private func _apply( attributeType: AttributeType? = nil, - matching filter: AttributeValueFilter? = nil, + matching filter: AttributeValueFilter = .matchAny, _ block: AsyncParticipantApplyFunction ) async rethrows { for attribute in _attributes { for attributeValue in attribute.value { - if let filter, - !attributeValue.matches(attributeType: attributeType, matching: filter) { continue } + if !attributeValue.matches(attributeType: attributeType, matching: filter) { continue } try await block(attributeValue) } } @@ -268,7 +231,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin ) } - private func _txOpportunity(eventSource: EventSource) async throws { + @Sendable + private func _onJoinTimerExpired() async throws { + let eventSource = EventSource.joinTimer + // this will send a .tx/.txLA event to all attributes which will then make // the appropriate state transitions, potentially triggering the encoding // of a vector @@ -282,50 +248,23 @@ public final actor Participant: Equatable, Hashable, CustomStrin try await _apply(protocolEvent: .tx, eventSource: eventSource) } try await _tx() + + // If events remain (e.g., arrived during TX processing or didn't fit in PDU), + // request another TX opportunity + if !_enqueuedEvents.isEmpty { + _requestTxOpportunity(eventSource: eventSource) + } } - private func _requestTxOpportunity(eventSource: EventSource) async throws { - guard let controller else { throw MRPError.internalError } + private func _requestTxOpportunity(eventSource: EventSource) { + _logger.trace("\(self): \(eventSource) requested TX opportunity") - let now = ContinuousClock.now + guard let jointimer = _jointimer, !jointimer.isRunning else { return } + guard let controller else { return } let joinTime = controller.timerConfiguration.joinTime - - if _type == .pointToPoint { - // Point-to-point: immediate transmission with rate limiting - // Remove timestamps older than 1.5 × JoinTime - let rateWindow = joinTime * 1.5 - _transmissionOpportunityTimestamps.removeAll { $0.duration(to: now) > rateWindow } - - // Check rate limit: max 3 transmissions per 1.5 × JoinTime - guard _transmissionOpportunityTimestamps.count < 3 else { - _logger - .trace( - "\(self): rate limiting TX opportunity, \(_transmissionOpportunityTimestamps.count) transmissions in last \(rateWindow)" - ) - _pendingTransmissionRequest = true - return - } - - _transmissionOpportunityTimestamps.append(now) - try await _txOpportunity(eventSource: eventSource) - } else { - // Shared media: randomized delay between 0 and JoinTime - if !_pendingTransmissionRequest { - _pendingTransmissionRequest = true - let randomDelay = Duration - .nanoseconds(Int64.random(in: 0..: Equatable, Hashable, CustomStrin return attributeValue } - public func findAttribute( + private func _findRegisteredAttributes( attributeType: AttributeType, - matching filter: AttributeValueFilter - ) -> (AttributeSubtype?, any Value)? { - let attributeValue = try? _findOrCreateAttribute( - attributeType: attributeType, - attributeSubtype: nil, - matching: filter, - createIfMissing: false - ) - // we allow attributes that are in the leaving state to be "found", because - // they haven't yet been timed out yet (and a leave indication issued) - guard let attributeValue, attributeValue.isRegistered else { - _logger.trace("\(self): could not find attribute type \(attributeType) matching \(filter)") - return nil - } - return (attributeValue.attributeSubtype, attributeValue.unwrappedValue) - } - - public func findAttributes( - attributeType: AttributeType, - matching filter: AttributeValueFilter - ) -> [(AttributeSubtype?, any Value)] { + matching filter: AttributeValueFilter = .matchAny + ) -> [_AttributeValue] { (_attributes[attributeType] ?? []) .filter { $0.matches(attributeType: attributeType, matching: filter) && $0.isRegistered } - .map { ($0.attributeSubtype, $0.unwrappedValue) } - } - - func findAllAttributes( - matching filter: AttributeValueFilter - ) -> [AttributeValue] { - _attributes.values.flatMap { $0 } - .map { AttributeValue( - attributeType: $0.attributeType, - attributeSubtype: $0.attributeSubtype, - attributeValue: $0.unwrappedValue, - applicantState: $0.applicantState, - registrarState: $0.registrarState - ) } - } - - func findAllAttributes( - attributeType: AttributeType, - matching filter: AttributeValueFilter - ) -> [AttributeValue] { - (_attributes[attributeType] ?? []) - .filter { $0.matches(attributeType: attributeType, matching: filter) } - .map { AttributeValue( - attributeType: attributeType, - attributeSubtype: $0.attributeSubtype, - attributeValue: $0.unwrappedValue, - applicantState: $0.applicantState, - registrarState: $0.registrarState - ) } - } - - public func leaveNow( - _ isIncluded: @Sendable (AttributeType, AttributeSubtype?, any Value) - -> Bool - ) async throws { - try await _leave(eventSource: .application, isLeaveAll: false, isIncluded) } fileprivate func _gcAttributeValue(_ attributeValue: _AttributeValue) { @@ -439,53 +323,27 @@ public final actor Participant: Equatable, Hashable, CustomStrin private func _handleAttributeValue( _ attributeValue: _AttributeValue, protocolEvent: ProtocolEvent, - eventSource: EventSource, - replacingAttributeSubtype: AttributeSubtype? = nil, - gcNow: Bool = false + eventSource: EventSource ) async throws { try await attributeValue.handle( protocolEvent: protocolEvent, - eventSource: eventSource, - replacingAttributeSubtype: replacingAttributeSubtype + eventSource: eventSource ) - - if gcNow, attributeValue.canGC { - _gcAttributeValue(attributeValue) - } } - private func _leave( + private func _leaveAll( eventSource: EventSource, - isLeaveAll: Bool, - _ isIncluded: @Sendable (AttributeType, AttributeSubtype?, any Value) -> Bool + attributeType leaveAllAttributeType: AttributeType ) async throws { - try await _apply { attributeValue in - guard isIncluded( - attributeValue.attributeType, - attributeValue.attributeSubtype, - attributeValue.unwrappedValue - ) else { - return - } - + try await _apply(attributeType: leaveAllAttributeType) { attributeValue in try await _handleAttributeValue( attributeValue, - protocolEvent: isLeaveAll ? .rLA : .rLv, - eventSource: eventSource, - gcNow: true + protocolEvent: .rLA, + eventSource: eventSource ) } } - private func _leaveAll( - eventSource: EventSource, - attributeType leaveAllAttributeType: AttributeType - ) async throws { - try await _leave(eventSource: eventSource, isLeaveAll: true) { attributeType, _, _ in - attributeType == leaveAllAttributeType - } - } - private func _chunkAttributeEvents(_ attributeEvents: [EnqueuedEvent.AttributeEvent]) -> [[EnqueuedEvent.AttributeEvent]] { @@ -565,10 +423,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin return messages } - private func _txEnqueue(_ event: EnqueuedEvent) { + private func _txEnqueue(_ event: EnqueuedEvent, eventSource: EventSource) { if let index = _enqueuedEvents.index(forKey: event.attributeType) { if let eventIndex = _enqueuedEvents.values[index] - .firstIndex(where: { $0.canBeReplacedBy(event) }) + .firstIndex(where: { $0 == event }) { _enqueuedEvents.values[index][eventIndex] = event } else { @@ -577,25 +435,27 @@ public final actor Participant: Equatable, Hashable, CustomStrin } else { _enqueuedEvents[event.attributeType] = [event] } + _requestTxOpportunity(eventSource: eventSource) } fileprivate func _txEnqueue( attributeEvent: AttributeEvent, attributeValue: _AttributeValue, - encodingOptional: Bool + encodingOptional: Bool, + eventSource: EventSource ) { let event = EnqueuedEvent.AttributeEvent( attributeEvent: attributeEvent, attributeValue: attributeValue, encodingOptional: encodingOptional ) - _txEnqueue(.attributeEvent(event)) + _txEnqueue(.attributeEvent(event), eventSource: eventSource) } - private func _txEnqueueLeaveAllEvents() throws { + private func _txEnqueueLeaveAllEvents(eventSource: EventSource) throws { guard let application else { throw MRPError.internalError } for attributeType in application.validAttributeTypes { - _txEnqueue(.leaveAllEvent(attributeType)) + _txEnqueue(.leaveAllEvent(attributeType), eventSource: eventSource) } } @@ -616,7 +476,7 @@ public final actor Participant: Equatable, Hashable, CustomStrin // registered attributes (Table 10-4), as well as requesting the // applicant to redeclare attributes (Table 10-3). try await _apply(protocolEvent: .rLA, eventSource: eventSource) - try _txEnqueueLeaveAllEvents() + try _txEnqueueLeaveAllEvents(eventSource: eventSource) default: break } @@ -637,15 +497,12 @@ public final actor Participant: Equatable, Hashable, CustomStrin return pdu } - private func rx(message: Message, sourceMacAddress: EUI48) async throws { - let eventSource: EventSource = _isEqualMacAddress( - sourceMacAddress, - port.macAddress - ) ? .local : .peer + private func rx(message: Message, eventSource: EventSource, leaveAll: inout Bool) async throws { for vectorAttribute in message.attributeList { // 10.6 Protocol operation: process LeaveAll first. if vectorAttribute.leaveAllEvent == .LeaveAll { try await _leaveAll(eventSource: eventSource, attributeType: message.attributeType) + leaveAll = true } let packedEvents = try vectorAttribute.attributeEvents @@ -663,11 +520,24 @@ public final actor Participant: Equatable, Hashable, CustomStrin createIfMissing: true ) else { continue } + // if a Bridge receives a MSRP JoinIn/JoinMt message with a different + // attribute subtype, it should behave as if a rLv! event with immediate + // leavetimer expiration was received. + if attributeEvent.protocolEvent == .rJoinIn || attributeEvent.protocolEvent == .rJoinMt, + let attributeSubtype, attribute.attributeSubtype != attributeSubtype + { + _logger + .debug( + "\(self): \(eventSource) declared attribute \(attribute) with new subtype \(attributeSubtype); replacing" + ) + try? await attribute.rLvNow(eventSource: eventSource) + attribute.attributeSubtype = attributeSubtype + } + try await _handleAttributeValue( attribute, protocolEvent: attributeEvent.protocolEvent, - eventSource: eventSource, - replacingAttributeSubtype: attributeSubtype + eventSource: eventSource ) } } @@ -675,8 +545,16 @@ public final actor Participant: Equatable, Hashable, CustomStrin func rx(pdu: MRPDU, sourceMacAddress: EUI48) async throws { _debugLogPdu(pdu, direction: .rx) + var leaveAll = false + let eventSource: EventSource = _isEqualMacAddress( + sourceMacAddress, + port.macAddress + ) ? .local : .peer for message in pdu.messages { - try await rx(message: message, sourceMacAddress: sourceMacAddress) + try await rx(message: message, eventSource: eventSource, leaveAll: &leaveAll) + } + if leaveAll { + try await _handleLeaveAll(protocolEvent: .rLA, eventSource: eventSource) } } @@ -754,9 +632,26 @@ public final actor Participant: Equatable, Hashable, CustomStrin _logger .debug("\(self): \(direction): -------------------------------------------------------------") } +} - func periodic() async throws { - try await _requestTxOpportunity(eventSource: .periodicTimer) +// MARK: - public APIs for use by applications + +public extension Participant { + func findAttribute( + attributeType: AttributeType, + matching filter: AttributeValueFilter + ) -> (AttributeSubtype?, any Value)? { + findAttributes(attributeType: attributeType, matching: filter).first + } + + func findAttributes( + attributeType: AttributeType, + matching filter: AttributeValueFilter = .matchAny + ) -> [(AttributeSubtype?, any Value)] { + _findRegisteredAttributes(attributeType: attributeType, matching: filter).map { ( + $0.attributeSubtype, + $0.unwrappedValue + ) } } // A Flush! event signals to the Registrar state machine that there is a @@ -800,11 +695,19 @@ public final actor Participant: Equatable, Hashable, CustomStrin createIfMissing: true ) + if !isNew, let attributeSubtype, attribute.attributeSubtype != attributeSubtype { _logger + .debug( + "\(self): \(eventSource) declared attribute \(attribute) with new subtype \(attributeSubtype); replacing" + ) + + try? await _handleAttributeValue(attribute, protocolEvent: .Lv, eventSource: eventSource) + attribute.attributeSubtype = attributeSubtype + } + try await _handleAttributeValue( attribute, protocolEvent: isNew ? .New : .Join, - eventSource: eventSource, - replacingAttributeSubtype: attributeSubtype + eventSource: eventSource ) } @@ -817,16 +720,70 @@ public final actor Participant: Equatable, Hashable, CustomStrin let attribute = try _findOrCreateAttribute( attributeType: attributeType, attributeSubtype: attributeSubtype, - matching: .matchEqual(attributeValue), // don't match on subtype, we want to replace it + matching: .matchEqual(attributeValue), createIfMissing: false ) try await _handleAttributeValue( attribute, protocolEvent: .Lv, - eventSource: eventSource, - replacingAttributeSubtype: attributeSubtype + eventSource: eventSource + ) + } + + func deregister( + attributeType: AttributeType, + attributeValue: some Value, + eventSource: EventSource + ) async throws { + let attribute = try _findOrCreateAttribute( + attributeType: attributeType, + attributeSubtype: nil, + matching: .matchEqual(attributeValue), + createIfMissing: false ) + + try await _handleAttributeValue( + attribute, + protocolEvent: .rLvNow, + eventSource: eventSource + ) + } + + func periodic() async throws { + _requestTxOpportunity(eventSource: .periodicTimer) + } +} + +// MARK: - for use by REST APIs + +extension Participant { + func findAllAttributes( + matching filter: AttributeValueFilter = .matchAny + ) -> [AttributeValue] { + _attributes.values.flatMap { $0 } + .map { AttributeValue( + attributeType: $0.attributeType, + attributeSubtype: $0.attributeSubtype, + attributeValue: $0.unwrappedValue, + applicantState: $0.applicantState, + registrarState: $0.registrarState + ) } + } + + func findAllAttributes( + attributeType: AttributeType, + matching filter: AttributeValueFilter = .matchAny + ) -> [AttributeValue] { + (_attributes[attributeType] ?? []) + .filter { $0.matches(attributeType: attributeType, matching: filter) } + .map { AttributeValue( + attributeType: attributeType, + attributeSubtype: $0.attributeSubtype, + attributeValue: $0.unwrappedValue, + applicantState: $0.applicantState, + registrarState: $0.registrarState + ) } } } @@ -868,7 +825,7 @@ Sendable, Hashable, Equatable, var participant: P? { _participant.object } var unwrappedValue: any Value { value.value } - private(set) var attributeSubtype: AttributeSubtype? { + var attributeSubtype: AttributeSubtype? { get { _attributeSubtype.withLock { $0 } } @@ -948,10 +905,6 @@ Sendable, Hashable, Equatable, protocolEvent: .leavetimer, eventSource: .leaveTimer ) - - if canGC, let participant { - await participant._gcAttributeValue(self) - } } func hash(into hasher: inout Hasher) { @@ -987,67 +940,61 @@ Sendable, Hashable, Equatable, private func _getEventContext( for event: ProtocolEvent, - eventSource: EventSource - ) async throws -> EventContext { - guard let participant else { throw MRPError.internalError } - - let smFlags = try await participant._getSmFlags(for: attributeType) - - return EventContext( + eventSource: EventSource, + isolation participant: isolated P + ) throws -> EventContext { + try EventContext( participant: participant, event: event, eventSource: eventSource, attributeType: attributeType, attributeSubtype: attributeSubtype, attributeValue: unwrappedValue, - smFlags: smFlags, + smFlags: participant._getSmFlags(for: attributeType), applicant: applicant, registrar: registrar ) } - func handle( + fileprivate func handle( + protocolEvent event: ProtocolEvent, + eventSource: EventSource + ) async throws { + guard let participant else { throw MRPError.internalError } + try await _handle(protocolEvent: event, eventSource: eventSource, isolation: participant) + } + + private func _handle( protocolEvent event: ProtocolEvent, eventSource: EventSource, - replacingAttributeSubtype subtype: AttributeSubtype? = nil + isolation participant: isolated P ) async throws { - // fast path for MSRP pre-applicant event handler: silently replace attribute - // subtypes as if the Listener declaration had been withdrawn and - // replaced by the updated Listener declaration (35.2.6) - if let subtype { attributeSubtype = subtype } + let context = try _getEventContext(for: event, eventSource: eventSource, isolation: participant) - let context = try await _getEventContext(for: event, eventSource: eventSource) + try await _handleRegistrar(context: context, isolation: context.participant) + try await _handleApplicant(context: context, isolation: context.participant) - try await _handleRegistrar(context: context) - try await _handleApplicant(context: context) + // remove attribute entirely if it is no longer declared or registered + if canGC { participant._gcAttributeValue(self) } } - private func _handleApplicant(context: EventContext) async throws { - context.participant._logger.trace("\(context.participant): handling applicant \(context)") + private func _handleApplicant( + context: EventContext, + isolation participant: isolated P + ) throws { + participant._logger.trace("\(context.participant): handling applicant \(context)") - let applicantAction = applicant.action(for: context.event, flags: context.smFlags) + guard let applicantAction = applicant.action(for: context.event, flags: context.smFlags) + else { return } - if let applicantAction { - context.participant._logger - .trace( - "\(context.participant): applicant action for event \(context.event): \(applicantAction)" - ) - let applicationEventHandler = context.participant - .application as? any ApplicationEventHandler - try await applicationEventHandler?.preApplicantEventHandler(context: context) - let attributeEvent = try await _handle(applicantAction: applicantAction, context: context) - applicationEventHandler?.postApplicantEventHandler(context: context) - counters.withLock { $0.count(context: context, attributeEvent: attributeEvent) } - } - } + participant._logger + .trace( + "\(context.participant): applicant action for event \(context.event): \(applicantAction)" + ) - private func _handle( - applicantAction action: Applicant.Action, - context: EventContext - ) async throws -> AttributeEvent? { var attributeEvent: AttributeEvent? - switch action { + switch applicantAction { case .sN: // The AttributeEvent value New is encoded in the Vector as specified in // 10.7.6.1. @@ -1075,37 +1022,36 @@ Sendable, Hashable, Equatable, } if let attributeEvent { - await context.participant._txEnqueue( + participant._txEnqueue( attributeEvent: attributeEvent, attributeValue: self, - encodingOptional: action.encodingOptional + encodingOptional: applicantAction.encodingOptional, + eventSource: context.eventSource ) } - return attributeEvent + counters.withLock { $0.count(context: context, attributeEvent: attributeEvent) } } - private func _handleRegistrar(context: EventContext) async throws { + private func _handleRegistrar( + context: EventContext, + isolation participant: isolated P + ) async throws { context.participant._logger.trace("\(context.participant): handling registrar \(context)") - if let registrarAction = context.registrar?.action(for: context.event, flags: context.smFlags) { - context.participant._logger - .trace( - "\(context.participant): registrar action for event \(context.event): \(registrarAction)" - ) - try await _handle( - registrarAction: registrarAction, - context: context - ) + guard let registrarAction = context.registrar? + .action(for: context.event, flags: context.smFlags) + else { + return } - } - private func _handle( - registrarAction action: Registrar.Action, - context: EventContext - ) async throws { + context.participant._logger + .trace( + "\(context.participant): registrar action for event \(context.event): \(registrarAction)" + ) + guard let application = context.participant.application else { throw MRPError.internalError } - switch action { + switch registrarAction { case .New: fallthrough case .Join: @@ -1115,7 +1061,7 @@ Sendable, Hashable, Equatable, attributeType: context.attributeType, attributeSubtype: context.attributeSubtype, attributeValue: context.attributeValue, - isNew: action == .New, + isNew: registrarAction == .New, eventSource: context.eventSource ) case .Lv: @@ -1129,6 +1075,16 @@ Sendable, Hashable, Equatable, ) } } + + fileprivate func rLvNow( + eventSource: EventSource + ) async throws { + try await handle( + protocolEvent: .rLvNow, + eventSource: eventSource + ) + precondition(!isRegistered) + } } private extension AttributeValueFilter { diff --git a/Sources/MRP/Model/Registrar.swift b/Sources/MRP/Model/Registrar.swift index 08ade1aa..981b519d 100644 --- a/Sources/MRP/Model/Registrar.swift +++ b/Sources/MRP/Model/Registrar.swift @@ -54,6 +54,8 @@ final class Registrar: Sendable, CustomStringConvertible { if state == .LV, event == .rNew || event == .rJoinIn || event == .rJoinMt { leaveTimerAction = .stop + } else if state != .MT, event == .rLvNow { + leaveTimerAction = .stop } else if state == .IN, event == .rLv || event == .rLA || event == .txLA || event == .ReDeclare { @@ -153,6 +155,11 @@ private extension Registrar.State { action = .Lv } self = .MT + case .rLvNow: + // behave as though .rLv was received with immediate leavetimer expiration + guard self != .MT else { break } + action = .Lv + self = .MT default: break }