Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 89 additions & 101 deletions Sources/MRP/Applications/MSRP/MSRPApplication.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ struct MSRPPortState<P: AVBPort>: Sendable {
}

public final class MSRPApplication<P: AVBPort>: BaseApplication, BaseApplicationEventObserver,
BaseApplicationContextObserver,
ApplicationEventHandler, CustomStringConvertible, @unchecked Sendable where P == P
BaseApplicationContextObserver, CustomStringConvertible, @unchecked Sendable where P == P
{
private typealias TalkerRegistration = (Participant<MSRPApplication>, any MSRPTalkerValue)

Expand Down Expand Up @@ -398,50 +397,6 @@ public final class MSRPApplication<P: AVBPort>: BaseApplication, BaseApplication
.normalParticipant
}

// If an MSRP message is received from a Port with an event value specifying
// the JoinIn or JoinMt message, and if the StreamID (35.2.2.8.2,
// 35.2.2.10.2), and Direction (35.2.1.2) all match those of an attribute
// already registered on that Port, and the Attribute Type (35.2.2.4) or
// FourPackedEvent (35.2.2.7.2) has changed, then the Bridge should behave as
// though an rLv! event (with immediate leavetimer expiration in the
// Registrar state table) was generated for the MAD in the Received MSRP
// Attribute Declarations before the rJoinIn! or rJoinMt! event for the
// attribute in the received message is processed
public func preApplicantEventHandler(
context: EventContext<MSRPApplication>
) async throws {
guard context.event == .rJoinIn || context.event == .rJoinMt else { return }

let contextAttributeType = MSRPAttributeType(rawValue: context.attributeType)!
guard let contextDirection = contextAttributeType.direction else { return }

let contextStreamID = (context.attributeValue as! MSRPStreamIDRepresentable).streamID

try await context.participant.leaveNow { attributeType, _, attributeValue in
let attributeType = MSRPAttributeType(rawValue: attributeType)!
guard let direction = attributeType.direction else { return false }
let streamID = (attributeValue as! MSRPStreamIDRepresentable).streamID

// force immediate leave if the streamID and direction match and the
// attribute type has changed
let isIncluded = contextStreamID == streamID &&
contextDirection == direction &&
contextAttributeType != attributeType

// NB: we don't handle attribute subtypes here, they are handled in
// Participant.swift (this is something of a leaky abstraction)
if isIncluded {
_logger
.debug(
"MSRP: forcing immediate leave for stream \(streamID) owing to attribute change: \(attributeType)->\(contextAttributeType)"
)
}
return isIncluded
}
}

public func postApplicantEventHandler(context: EventContext<MSRPApplication>) {}

// On receipt of a REGISTER_STREAM.request the MSRP Participant shall issue a
// MAD_Join.request service primitive (10.2, 10.3). The attribute_type (10.2)
// parameter of the request shall carry the appropriate Talker Attribute Type
Expand Down Expand Up @@ -511,11 +466,10 @@ public final class MSRPApplication<P: AVBPort>: BaseApplication, BaseApplication
public func deregisterStream(
streamID: MSRPStreamID
) async throws {
let talkerRegistration = try await _findTalkerRegistration(for: streamID)
let declarationType: MSRPDeclarationType
guard let talkerRegistration else {
guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else {
throw MRPError.participantNotFound
}
let declarationType: MSRPDeclarationType
if talkerRegistration.1 is MSRPTalkerAdvertiseValue {
declarationType = .talkerAdvertise
} else {
Expand Down Expand Up @@ -578,6 +532,38 @@ public final class MSRPApplication<P: AVBPort>: BaseApplication, BaseApplication
}

extension MSRPApplication {
// Enforce mutual exclusion between talkerAdvertise and talkerFailed on a participant
private func _enforceTalkerMutualExclusion(
participant: Participant<MSRPApplication>,
declarationType: MSRPDeclarationType,
streamID: MSRPStreamID,
eventSource: EventSource
) async throws {
let oppositeType: MSRPAttributeType = declarationType == .talkerAdvertise ? .talkerFailed :
.talkerAdvertise

let oppositeAttributes = await participant.findAttributes(
attributeType: oppositeType.rawValue,
matching: .matchAnyIndex(streamID.index)
)

for (_, attributeValue) in oppositeAttributes {
if eventSource == .map {
try? await participant.leave(
attributeType: oppositeType.rawValue,
attributeValue: attributeValue,
eventSource: eventSource
)
} else {
try? await participant.deregister(
attributeType: oppositeType.rawValue,
attributeValue: attributeValue,
eventSource: eventSource
)
}
}
}

private func _shouldPruneTalkerDeclaration(
port: P,
streamID: MSRPStreamID,
Expand Down Expand Up @@ -892,6 +878,17 @@ extension MSRPApplication {
"MSRP: register stream indication from port \(port) streamID \(talkerValue.streamID) declarationType \(declarationType) dataFrameParameters \(talkerValue.dataFrameParameters) isNew \(isNew) source \(eventSource)"
)

// Deregister the opposite talker type from the peer to ensure mutual exclusion
if eventSource == .peer {
let sourceParticipant = try findParticipant(for: contextIdentifier, port: port)
try await _enforceTalkerMutualExclusion(
participant: sourceParticipant,
declarationType: declarationType,
streamID: talkerValue.streamID,
eventSource: .peer
)
}

// TL;DR: propagate Talker declarations to other ports
try await apply(for: contextIdentifier) { participant in
guard participant.port != port else { return } // don't propagate to source port
Expand Down Expand Up @@ -930,12 +927,13 @@ extension MSRPApplication {
accumulatedLatency += 500 // clause 35.2.2.8.6, 500ns default
}

// if a Talker Failed attribute already exists when a Talker Advertised
// is registered, or vice versa, immediately deregister the existing
// attribute
try await participant.leaveStreamNow(
attributeType: declarationType.talkerComplement,
streamID: talkerValue.streamID
// Leave the opposite talker declaration type to ensure mutual exclusion
// (per spec, only one talker declaration type should exist per stream)
try await _enforceTalkerMutualExclusion(
participant: participant,
declarationType: declarationType,
streamID: talkerValue.streamID,
eventSource: .map
)

if declarationType == .talkerAdvertise {
Expand Down Expand Up @@ -970,11 +968,6 @@ extension MSRPApplication {
eventSource: .map
)
} catch let error as MSRPFailure {
// Leave any existing talkerAdvertise before joining talkerFailed
try await participant.leaveNow { attributeType, _, attributeValue in
attributeType == MSRPAttributeType.talkerAdvertise.rawValue &&
(attributeValue as! MSRPStreamIDRepresentable).streamID == talkerValue.streamID
}
let talkerFailed = MSRPTalkerFailedValue(
streamID: talkerValue.streamID,
dataFrameParameters: talkerValue.dataFrameParameters,
Expand Down Expand Up @@ -1060,11 +1053,24 @@ extension MSRPApplication {
) else {
return
}

// verify talker still exists (guard against race with talker departure)
guard let currentTalker = await _findTalkerRegistration(
for: talkerValue.streamID,
participant: talkerParticipant
), currentTalker.streamID == talkerValue.streamID else {
_logger
.debug(
"MSRP: talker \(talkerValue.streamID) no longer exists, skipping port parameter update"
)
return
}

try? await _updatePortParameters(
port: participant.port,
streamID: listenerRegistration.0.streamID,
mergedDeclarationType: mergedDeclarationType,
talkerRegistration: (talkerParticipant, talkerValue)
talkerRegistration: (talkerParticipant, currentTalker)
)
}
}
Expand Down Expand Up @@ -1140,13 +1146,14 @@ extension MSRPApplication {
for streamID: MSRPStreamID,
participant: Participant<MSRPApplication>
) async -> (any MSRPTalkerValue)? {
// TalkerFailed takes precedence over TalkerAdvertise per spec
if let value = await participant.findAttribute(
attributeType: MSRPAttributeType.talkerAdvertise.rawValue,
attributeType: MSRPAttributeType.talkerFailed.rawValue,
matching: .matchAnyIndex(streamID.index)
) {
value.1 as? (any MSRPTalkerValue)
} else if let value = await participant.findAttribute(
attributeType: MSRPAttributeType.talkerFailed.rawValue,
attributeType: MSRPAttributeType.talkerAdvertise.rawValue,
matching: .matchAnyIndex(streamID.index)
) {
value.1 as? (any MSRPTalkerValue)
Expand All @@ -1157,19 +1164,17 @@ extension MSRPApplication {

private func _findTalkerRegistration(
for streamID: MSRPStreamID
) async throws -> TalkerRegistration? {
) async -> TalkerRegistration? {
var talkerRegistration: TalkerRegistration?

await apply { participant in
guard let participantTalker = await _findTalkerRegistration(
for: streamID,
participant: participant
) else {
), talkerRegistration == nil else {
return
}
if talkerRegistration == nil {
talkerRegistration = (participant, participantTalker)
}
talkerRegistration = (participant, participantTalker)
}

return talkerRegistration
Expand All @@ -1182,17 +1187,9 @@ extension MSRPApplication {
talkerRegistration: TalkerRegistration,
isJoin: Bool
) async throws -> MSRPDeclarationType? {
var mergedDeclarationType = isJoin ? declarationType : nil
let streamID = talkerRegistration.1.streamID

var mergedDeclarationType: MSRPDeclarationType? = if isJoin {
if talkerRegistration.1 is MSRPTalkerFailedValue {
declarationType == nil ? nil : .listenerAskingFailed
} else {
declarationType
}
} else {
nil
}
var listenerCount = mergedDeclarationType != nil ? 1 : 0

// collect listener declarations from all other ports and merge declaration type
await apply(for: contextIdentifier) { participant in
Expand All @@ -1217,9 +1214,16 @@ extension MSRPApplication {
with: mergedDeclarationType
)
}
listenerCount += 1
}
}

precondition(mergedDeclarationType == nil || listenerCount > 0)

if talkerRegistration.1 is MSRPTalkerFailedValue, listenerCount > 0 {
mergedDeclarationType = .listenerAskingFailed
}

return mergedDeclarationType
}

Expand Down Expand Up @@ -1435,10 +1439,13 @@ extension MSRPApplication {
isNew: Bool,
eventSource: EventSource
) async throws {
guard let talkerRegistration = try? await _findTalkerRegistration(for: streamID) else {
guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else {
// no listener attribute propagation if no talker (35.2.4.4.1)
// this is an expected race condition - listener arrives before talker
// when talker arrives, _updateExistingListeners() will process it
_logger
.error(
"MSRP: could not find talker registration for listener stream \(streamID)"
.debug(
"MSRP: listener registration for stream \(streamID) received before talker, will be processed when talker arrives"
)
return
}
Expand Down Expand Up @@ -1616,7 +1623,7 @@ extension MSRPApplication {
// StreamID of the Declaration matches a Stream that the Talker is
// transmitting, then the Talker shall stop the transmission for this
// Stream, if it is transmitting.
guard let talkerRegistration = try? await _findTalkerRegistration(for: streamID) else {
guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else {
return
}

Expand Down Expand Up @@ -1783,25 +1790,6 @@ extension MSRPApplication {
}
}

private extension MSRPDeclarationType {
var talkerComplement: MSRPAttributeType! {
switch self {
case .talkerAdvertise: .talkerFailed
case .talkerFailed: .talkerAdvertise
default: nil
}
}
}

private extension Participant {
func leaveStreamNow(attributeType: MSRPAttributeType, streamID: MSRPStreamID) async throws {
try await leaveNow {
$0 == attributeType.rawValue &&
($2 as! MSRPStreamIDRepresentable).streamID == streamID
}
}
}

#if canImport(FlyingFox)
extension MSRPApplication: RestApiApplication {
func registerRestApiHandlers(for httpServer: HTTPServer) async throws {
Expand Down
3 changes: 2 additions & 1 deletion Sources/MRP/Base/Event.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public enum ProtocolEvent: Sendable {
case leavetimer // leavetimer has expired (10.7.5.21)
case leavealltimer // leavealltimer has expired (10.7.5.22)
case periodictimer // periodictimer has expired (10.7.5.23)
case rLvNow // receive Leave message with immediate leavetimer expiration (35.2.6)

fileprivate var _r: Bool {
switch self {
Expand Down Expand Up @@ -100,7 +101,7 @@ public enum EventSource: Sendable {
case `internal`
// event source was transitive via MAP function
case map
// event source was a preApplicantEventHandler/postApplicantEventHandler hook
// event source was a event handler hook
case application
// event source was immediate re-registration after LeaveAll processing
case leaveAll
Expand Down
2 changes: 2 additions & 0 deletions Sources/MRP/Model/Applicant.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ private extension Applicant.State {
}
case .rLv:
fallthrough
case .rLvNow:
fallthrough
case .rLA:
fallthrough
case .ReDeclare:
Expand Down
7 changes: 0 additions & 7 deletions Sources/MRP/Model/Application.swift
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,3 @@ extension Application {
try await apply(for: contextIdentifier) { try await $0.redeclare() }
}
}

public protocol ApplicationEventHandler<A>: Application {
associatedtype A: Application

func preApplicantEventHandler(context: EventContext<A>) async throws
func postApplicantEventHandler(context: EventContext<A>)
}
Loading