diff --git a/go.mod b/go.mod index 5b6385600..5b84ea929 100644 --- a/go.mod +++ b/go.mod @@ -10,19 +10,20 @@ require ( github.com/lucas-clemente/quic-go v0.18.1 // indirect github.com/lucsky/cuid v1.0.2 github.com/marten-seemann/qtls-go1-15 v0.1.1 // indirect - github.com/pion/ion-log v0.0.0-20201018162658-5afa48038e76 + github.com/matryer/moq v0.1.3 // indirect + github.com/pion/ion-log v0.0.0-20201024224650-e6b94dfeaf1d github.com/pion/rtcp v1.2.4 github.com/pion/rtp v1.6.1 github.com/pion/sdp/v3 v3.0.2 github.com/pion/turn/v2 v2.0.5 // indirect - github.com/pion/webrtc/v3 v3.0.0-beta.10 - github.com/rs/zerolog v1.19.0 // indirect + github.com/pion/webrtc/v3 v3.0.0-beta.10.0.20201025013753-76bc99210140 + github.com/rs/zerolog v1.20.0 // indirect github.com/sourcegraph/jsonrpc2 v0.0.0-20200429184054-15c2290dcb37 github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 // indirect - golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 // indirect - golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13 // indirect + golang.org/x/net v0.0.0-20201024042810-be3efd7ff127 // indirect + golang.org/x/sys v0.0.0-20201024232916-9f70ab9862d5 // indirect google.golang.org/grpc v1.33.1 google.golang.org/protobuf v1.25.0 gopkg.in/ini.v1 v1.51.1 // indirect diff --git a/go.sum b/go.sum index a073c7e81..2f6ff932a 100644 --- a/go.sum +++ b/go.sum @@ -184,6 +184,8 @@ github.com/marten-seemann/qtls v0.10.0/go.mod h1:UvMd1oaYDACI99/oZUYLzMCkBXQVT0a github.com/marten-seemann/qtls-go1-15 v0.1.0/go.mod h1:GyFwywLKkRt+6mfU99csTEY1joMZz5vmB1WNZH3P81I= github.com/marten-seemann/qtls-go1-15 v0.1.1 h1:LIH6K34bPVttyXnUWixk0bzH6/N07VxbSabxn5A5gZQ= github.com/marten-seemann/qtls-go1-15 v0.1.1/go.mod h1:GyFwywLKkRt+6mfU99csTEY1joMZz5vmB1WNZH3P81I= +github.com/matryer/moq v0.1.3 h1:+fW3u2jmlPw59a3V6spZKOLCcvrDKzPjMsRvUhnZ/c0= +github.com/matryer/moq v0.1.3/go.mod h1:9RtPYjTnH1bSBIkpvtHkFN7nbWAnO7oRpdJkEIn6UtE= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -224,6 +226,8 @@ github.com/pion/ice/v2 v2.0.9 h1:oHbiN6Q9tgb8Gfu3I4cbr5mHRE1uqiuFABQ8CbWjIyk= github.com/pion/ice/v2 v2.0.9/go.mod h1:NK+o39ynb+N1YSj9fPgWs3vjVcrsWw0KCr/311MqVq8= github.com/pion/ion-log v0.0.0-20201018162658-5afa48038e76 h1:e1+7hmitdiDKC3lGwTG3IqDXP+9+kjtr6Aa7iZ8tXHg= github.com/pion/ion-log v0.0.0-20201018162658-5afa48038e76/go.mod h1:turscGxpzm5X2PUoMMqAObZPHvk3iujBWcR6bnj1fWY= +github.com/pion/ion-log v0.0.0-20201024224650-e6b94dfeaf1d h1:lG2DmuOV2bSoncAABwIvFOi4/yptRsk/n9PkXbMAxT0= +github.com/pion/ion-log v0.0.0-20201024224650-e6b94dfeaf1d/go.mod h1:jwcla9KoB9bB/4FxYDSRJPcPYSLp5XiUUMnOLaqwl4E= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY= @@ -256,6 +260,8 @@ github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI= github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths= github.com/pion/webrtc/v3 v3.0.0-beta.10 h1:1aBn9jv/oe4v2Uf47HutWIjg2i2ZP/O7HqpgKPqSuhE= github.com/pion/webrtc/v3 v3.0.0-beta.10/go.mod h1:GlriYYHJ5KkNsCunm3oFDPql4TDTrrNoI9iSWWSnafA= +github.com/pion/webrtc/v3 v3.0.0-beta.10.0.20201025013753-76bc99210140 h1:lrXBeoiA3A2fjytKlmvgc3E18to3jNf9V8RBPCVfGWs= +github.com/pion/webrtc/v3 v3.0.0-beta.10.0.20201025013753-76bc99210140/go.mod h1:GlriYYHJ5KkNsCunm3oFDPql4TDTrrNoI9iSWWSnafA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -281,6 +287,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.19.0 h1:hYz4ZVdUgjXTBUmrkrw55j1nHx68LfOKIQk5IYtyScg= github.com/rs/zerolog v1.19.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= +github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= +github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= @@ -344,6 +352,7 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1 github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -361,6 +370,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -385,6 +395,8 @@ golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -412,6 +424,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20201002202402-0a1ea396d57c/go.mod h1:iQL9McJNjoIa5mjH6nYTCTZXUN6RP+XW3eib7Ya3XcI= golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg= golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201024042810-be3efd7ff127 h1:pZPp9+iYUqwYKLjht0SDBbRCRK/9gAXDy7pz5fRDpjo= +golang.org/x/net v0.0.0-20201024042810-be3efd7ff127/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -424,6 +438,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -449,6 +464,8 @@ golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13 h1:5jaG59Zhd+8ZXe8C+lgiAGqkOaZBruqrWclLkgAww34= golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201024232916-9f70ab9862d5 h1:iCaAy5bMeEvwANu3YnJfWwI0kWAGkEa2RXPdweI/ysk= +golang.org/x/sys v0.0.0-20201024232916-9f70ab9862d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -479,7 +496,11 @@ golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200815165600-90abf76919f3 h1:0aScV/0rLmANzEYIhjCOi2pTvDyhZNduBUMD2q3iqs4= +golang.org/x/tools v0.0.0-20200815165600-90abf76919f3/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/api.go b/pkg/api.go index be344af5d..83b038ef9 100644 --- a/pkg/api.go +++ b/pkg/api.go @@ -40,17 +40,17 @@ func handleAPICommand(t Transport, dc *webrtc.DataChannel) { case videoHighQuality: sender.Mute(false) if sender.Type() != SimpleSenderType { - sender.SwitchSpatialLayer(3) + sender.SwitchSpatialLayer(2) } case videoMediumQuality: sender.Mute(false) if sender.Type() != SimpleSenderType { - sender.SwitchSpatialLayer(2) + sender.SwitchSpatialLayer(1) } case videoLowQuality: sender.Mute(false) if sender.Type() != SimpleSenderType { - sender.SwitchSpatialLayer(1) + sender.SwitchSpatialLayer(0) } case videoMuted: sender.Mute(true) diff --git a/pkg/buffer.go b/pkg/buffer.go index 76d7048bd..c5445c727 100644 --- a/pkg/buffer.go +++ b/pkg/buffer.go @@ -1,12 +1,9 @@ package sfu import ( - "math" - "sort" "sync" "time" - "github.com/gammazero/deque" log "github.com/pion/ion-log" "github.com/pion/rtcp" "github.com/pion/rtp" @@ -17,12 +14,9 @@ const ( maxSN = 1 << 16 // default buffer time by ms defaultBufferTime = 1000 -) -type rtpExtInfo struct { - ExtTSN uint32 - Timestamp int64 -} + reportDelta = 1e9 +) // Buffer contains all packets type Buffer struct { @@ -31,8 +25,11 @@ type Buffer struct { pktQueue queue codecType webrtc.RTPCodecType simulcast bool + mediaSSRC uint32 clockRate uint32 maxBitrate uint64 + lastReport int64 + twccExt uint8 // supported feedbacks remb bool @@ -47,7 +44,6 @@ type Buffer struct { lastExpected uint32 lastReceived uint32 lostRate float32 - ssrc uint32 lastPacketTime int64 // Time the last RTP packet from this source was received lastRtcpPacketTime int64 // Time the last RTCP packet was received. lastRtcpSrTime int64 // Time the last RTCP SR was received. Required for DLSR computation. @@ -56,20 +52,14 @@ type Buffer struct { maxSeqNo uint16 // The highest sequence number received in an RTP data packet jitter float64 // An estimate of the statistical variance of the RTP data packet inter-arrival time. totalByte uint64 - - // transport-cc - tccExt uint8 - tccExtInfo []rtpExtInfo - tccCycles uint32 - tccLastExtSN uint32 - tccPktCtn uint8 - tccLastSn uint16 - lastExtInfo uint16 + // callbacks + feedbackTWCC func(sn uint16, timeNS int64, marker bool) + feedbackCB func([]rtcp.Packet) } // BufferOptions provides configuration options for the buffer type BufferOptions struct { - TCCExt int + TWCCExt int BufferTime int MaxBitRate uint64 } @@ -77,18 +67,18 @@ type BufferOptions struct { // NewBuffer constructs a new Buffer func NewBuffer(track *webrtc.Track, o BufferOptions) *Buffer { b := &Buffer{ - ssrc: track.SSRC(), + mediaSSRC: track.SSRC(), clockRate: track.Codec().ClockRate, codecType: track.Codec().Type, maxBitrate: o.MaxBitRate, simulcast: len(track.RID()) > 0, + twccExt: uint8(o.TWCCExt), } if o.BufferTime <= 0 { o.BufferTime = defaultBufferTime } b.pktQueue.duration = uint32(o.BufferTime) * b.clockRate / 1000 b.pktQueue.ssrc = track.SSRC() - b.tccExt = uint8(o.TCCExt) for _, fb := range track.Codec().RTCPFeedback { switch fb.Type { @@ -97,7 +87,6 @@ func NewBuffer(track *webrtc.Track, o BufferOptions) *Buffer { b.remb = true case webrtc.TypeRTCPFBTransportCC: log.Debugf("Setting feedback %s", webrtc.TypeRTCPFBTransportCC) - b.tccExtInfo = make([]rtpExtInfo, 1<<8) b.tcc = true case webrtc.TypeRTCPFBNACK: log.Debugf("Setting feedback %s", webrtc.TypeRTCPFBNACK) @@ -108,15 +97,17 @@ func NewBuffer(track *webrtc.Track, o BufferOptions) *Buffer { return b } -// Push adds a RTP Packet, out of order, new packet may be arrived later -func (b *Buffer) Push(p *rtp.Packet) { +// push adds a RTP Packet, out of order, new packet may be arrived later +func (b *Buffer) push(p *rtp.Packet) { b.mu.Lock() defer b.mu.Unlock() + b.lastPacketTime = time.Now().UnixNano() b.totalByte += uint64(p.MarshalSize()) if b.packetCount == 0 { b.baseSN = p.SequenceNumber b.maxSeqNo = p.SequenceNumber b.pktQueue.headSN = p.SequenceNumber - 1 + b.lastReport = b.lastPacketTime } else if snDiff(b.maxSeqNo, p.SequenceNumber) <= 0 { if p.SequenceNumber < b.maxSeqNo { b.cycles += maxSN @@ -124,7 +115,6 @@ func (b *Buffer) Push(p *rtp.Packet) { b.maxSeqNo = p.SequenceNumber } b.packetCount++ - b.lastPacketTime = time.Now().UnixNano() arrival := uint32(b.lastPacketTime / 1e6 * int64(b.clockRate/1e3)) transit := arrival - p.Timestamp if b.lastTransit != 0 { @@ -141,16 +131,16 @@ func (b *Buffer) Push(p *rtp.Packet) { if b.tcc { rtpTCC := rtp.TransportCCExtension{} - if err := rtpTCC.Unmarshal(p.GetExtension(b.tccExt)); err == nil { - if rtpTCC.TransportSequence < 0x0fff && (b.tccLastSn&0xffff) > 0xf000 { - b.tccCycles += maxSN - } - b.tccExtInfo = append(b.tccExtInfo, rtpExtInfo{ - ExtTSN: b.tccCycles | uint32(rtpTCC.TransportSequence), - Timestamp: b.lastPacketTime / 1e3, - }) + if err := rtpTCC.Unmarshal(p.GetExtension(b.twccExt)); err == nil { + b.feedbackTWCC(rtpTCC.TransportSequence, b.lastPacketTime, p.Marker) } } + + if b.lastPacketTime-b.lastReport >= reportDelta { + b.feedbackCB(b.getRTCP()) + b.lastReport = b.lastPacketTime + } + } func (b *Buffer) buildREMBPacket() *rtcp.ReceiverEstimatedMaximumBitrate { @@ -170,187 +160,9 @@ func (b *Buffer) buildREMBPacket() *rtcp.ReceiverEstimatedMaximumBitrate { b.totalByte = 0 return &rtcp.ReceiverEstimatedMaximumBitrate{ - SenderSSRC: b.ssrc, - Bitrate: br, - SSRCs: []uint32{b.ssrc}, - } -} - -func (b *Buffer) buildTransportCCPacket() *rtcp.TransportLayerCC { - if len(b.tccExtInfo) == 0 { - return nil - } - sort.Slice(b.tccExtInfo, func(i, j int) bool { - return b.tccExtInfo[i].ExtTSN < b.tccExtInfo[j].ExtTSN - }) - tccPkts := make([]rtpExtInfo, 0, int(float64(len(b.tccExtInfo))*1.2)) - for _, tccExtInfo := range b.tccExtInfo { - if tccExtInfo.ExtTSN < b.tccLastExtSN { - continue - } - if b.tccLastExtSN != 0 { - for j := b.tccLastExtSN + 1; j < tccExtInfo.ExtTSN; j++ { - tccPkts = append(tccPkts, rtpExtInfo{ExtTSN: j}) - } - } - b.tccLastExtSN = tccExtInfo.ExtTSN - tccPkts = append(tccPkts, tccExtInfo) - } - b.tccExtInfo = b.tccExtInfo[:0] - - rtcpTCC := &rtcp.TransportLayerCC{ - Header: rtcp.Header{ - Padding: true, - Count: rtcp.FormatTCC, - Type: rtcp.TypeTransportSpecificFeedback, - }, - MediaSSRC: b.ssrc, - BaseSequenceNumber: uint16(tccPkts[0].ExtTSN), - PacketStatusCount: uint16(len(tccPkts)), - FbPktCount: b.tccPktCtn, - } - b.tccPktCtn++ - - firstRecv := false - allSame := true - timestamp := int64(0) - deltaLen := 0 - lastStatus := rtcp.TypeTCCPacketReceivedWithoutDelta - maxStatus := rtcp.TypeTCCPacketNotReceived - - var statusList deque.Deque - - for _, stat := range tccPkts { - status := rtcp.TypeTCCPacketNotReceived - if stat.Timestamp != 0 { - var delta int64 - if !firstRecv { - firstRecv = true - timestamp = stat.Timestamp - rtcpTCC.ReferenceTime = uint32(stat.Timestamp / 64000) - } - - delta = (stat.Timestamp - timestamp) / 250 - if delta < 0 || delta > 255 { - status = rtcp.TypeTCCPacketReceivedLargeDelta - rDelta := int16(delta) - if int64(rDelta) != delta { - if rDelta > 0 { - rDelta = math.MaxInt16 - } else { - rDelta = math.MinInt16 - } - } - rtcpTCC.RecvDeltas = append(rtcpTCC.RecvDeltas, &rtcp.RecvDelta{ - Type: status, - Delta: int64(rDelta) * 250, - }) - deltaLen += 2 - } else { - status = rtcp.TypeTCCPacketReceivedSmallDelta - rtcpTCC.RecvDeltas = append(rtcpTCC.RecvDeltas, &rtcp.RecvDelta{ - Type: status, - Delta: delta * 250, - }) - deltaLen++ - } - timestamp = stat.Timestamp - } - - if allSame && lastStatus != rtcp.TypeTCCPacketReceivedWithoutDelta && status != lastStatus { - if statusList.Len() > 7 { - rtcpTCC.PacketChunks = append(rtcpTCC.PacketChunks, &rtcp.RunLengthChunk{ - PacketStatusSymbol: lastStatus, - RunLength: uint16(statusList.Len()), - }) - statusList.Clear() - lastStatus = rtcp.TypeTCCPacketReceivedWithoutDelta - maxStatus = rtcp.TypeTCCPacketNotReceived - allSame = true - } else { - allSame = false - } - } - statusList.PushBack(status) - if status > maxStatus { - maxStatus = status - } - lastStatus = status - - if !allSame { - if maxStatus == rtcp.TypeTCCPacketReceivedLargeDelta && statusList.Len() > 6 { - symbolList := make([]uint16, 7) - for i := 0; i < 7; i++ { - symbolList[i] = statusList.PopFront().(uint16) - } - rtcpTCC.PacketChunks = append(rtcpTCC.PacketChunks, &rtcp.StatusVectorChunk{ - SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, - SymbolList: symbolList, - }) - lastStatus = rtcp.TypeTCCPacketReceivedWithoutDelta - maxStatus = rtcp.TypeTCCPacketNotReceived - allSame = true - - for i := 0; i < statusList.Len(); i++ { - status = statusList.At(i).(uint16) - if status > maxStatus { - maxStatus = status - } - if allSame && lastStatus != rtcp.TypeTCCPacketReceivedWithoutDelta && status != lastStatus { - allSame = false - } - lastStatus = status - } - } else if statusList.Len() > 13 { - symbolList := make([]uint16, 14) - for i := 0; i < 14; i++ { - symbolList[i] = statusList.PopFront().(uint16) - } - rtcpTCC.PacketChunks = append(rtcpTCC.PacketChunks, &rtcp.StatusVectorChunk{ - SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, - SymbolList: symbolList, - }) - lastStatus = rtcp.TypeTCCPacketReceivedWithoutDelta - maxStatus = rtcp.TypeTCCPacketNotReceived - allSame = true - } - } - } - - if statusList.Len() > 0 { - if allSame { - rtcpTCC.PacketChunks = append(rtcpTCC.PacketChunks, &rtcp.RunLengthChunk{ - PacketStatusSymbol: lastStatus, - RunLength: uint16(statusList.Len()), - }) - } else if maxStatus == rtcp.TypeTCCPacketReceivedLargeDelta { - symbolList := make([]uint16, statusList.Len()) - for i := 0; i < statusList.Len(); i++ { - symbolList[i] = statusList.PopFront().(uint16) - } - rtcpTCC.PacketChunks = append(rtcpTCC.PacketChunks, &rtcp.StatusVectorChunk{ - SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, - SymbolList: symbolList, - }) - } else { - symbolList := make([]uint16, statusList.Len()) - for i := 0; i < statusList.Len(); i++ { - symbolList[i] = statusList.PopFront().(uint16) - } - rtcpTCC.PacketChunks = append(rtcpTCC.PacketChunks, &rtcp.StatusVectorChunk{ - SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, - SymbolList: symbolList, - }) - } - } - - pLen := uint16(20 + len(rtcpTCC.PacketChunks)*2 + deltaLen) - rtcpTCC.Header.Padding = pLen%4 != 0 - for pLen%4 != 0 { - pLen++ + Bitrate: br, + SSRCs: []uint32{b.mediaSSRC}, } - rtcpTCC.Header.Length = (pLen / 4) - 1 - return rtcpTCC } func (b *Buffer) buildReceptionReport() rtcp.ReceptionReport { @@ -381,7 +193,7 @@ func (b *Buffer) buildReceptionReport() rtcp.ReceptionReport { } rr := rtcp.ReceptionReport{ - SSRC: b.ssrc, + SSRC: b.mediaSSRC, FractionLost: fracLost, TotalLost: lost, LastSequenceNumber: extMaxSeq, @@ -400,25 +212,18 @@ func (b *Buffer) setSenderReportData(rtpTime uint32, ntpTime uint64) { b.lastSRRecv = time.Now().UnixNano() } -func (b *Buffer) getRTCP() (rtcp.ReceptionReport, []rtcp.Packet) { - b.mu.RLock() - defer b.mu.RUnlock() +func (b *Buffer) getRTCP() []rtcp.Packet { var pkts []rtcp.Packet - var report rtcp.ReceptionReport - report = b.buildReceptionReport() + pkts = append(pkts, &rtcp.ReceiverReport{ + Reports: []rtcp.ReceptionReport{b.buildReceptionReport()}, + }) - if b.remb { + if b.remb && !b.tcc { pkts = append(pkts, b.buildREMBPacket()) } - if b.tcc { - if tccPkt := b.buildTransportCCPacket(); tccPkt != nil { - pkts = append(pkts, tccPkt) - } - } - - return report, pkts + return pkts } // WritePacket write buffer packet to requested track. and modify headers @@ -427,10 +232,13 @@ func (b *Buffer) WritePacket(sn uint16, track *webrtc.Track, snOffset uint16, ts defer b.mu.RUnlock() if bufferPkt := b.pktQueue.GetPacket(sn); bufferPkt != nil { bSsrc := bufferPkt.SSRC + bPT := bufferPkt.PayloadType + bufferPkt.PayloadType = track.PayloadType() bufferPkt.SequenceNumber -= snOffset bufferPkt.Timestamp -= tsOffset bufferPkt.SSRC = ssrc err := track.WriteRTP(bufferPkt) + bufferPkt.PayloadType = bPT bufferPkt.Timestamp += tsOffset bufferPkt.SequenceNumber += snOffset bufferPkt.SSRC = bSsrc @@ -444,3 +252,11 @@ func (b *Buffer) onLostHandler(fn func(nack *rtcp.TransportLayerNack)) { b.pktQueue.onLost = fn } } + +func (b *Buffer) onFeedback(fn func(fb []rtcp.Packet)) { + b.feedbackCB = fn +} + +func (b *Buffer) onTransportWideCC(fn func(sn uint16, timeNS int64, marker bool)) { + b.feedbackTWCC = fn +} diff --git a/pkg/helpers.go b/pkg/helpers.go index 5a170033d..1feffec47 100644 --- a/pkg/helpers.go +++ b/pkg/helpers.go @@ -163,9 +163,12 @@ func timeToNtp(ns int64) uint64 { return seconds<<32 | fraction } -// fromNtp converts a NTP timestamp into GO time -func fromNtp(seconds, fraction uint32) (tm int64) { - n := (int64(fraction) * 1e9) >> 32 - tm = (int64(seconds)-ntpEpoch)*1e9 + n - return +// setNBitsOfUint16 will truncate the value to size, left-shift to startIndex position and set +func setNBitsOfUint16(src, size, startIndex, val uint16) uint16 { + if startIndex+size > 16 { + return 0 + } + // truncate val to size bits + val &= (1 << size) - 1 + return src | (val << (16 - size - startIndex)) } diff --git a/pkg/mediaengine.go b/pkg/mediaengine.go index 37f619fe3..455928d20 100644 --- a/pkg/mediaengine.go +++ b/pkg/mediaengine.go @@ -25,7 +25,6 @@ var ( // MediaEngine handles stream codecs type MediaEngine struct { webrtc.MediaEngine - tCCExt int } // PopulateFromSDP finds all codecs in sd and adds them to m, using the dynamic @@ -44,13 +43,6 @@ func (e *MediaEngine) PopulateFromSDP(sd webrtc.SessionDescription) error { continue } - for _, att := range md.Attributes { - if att.Key == sdp.AttrKeyExtMap && strings.HasSuffix(att.Value, sdp.TransportCCURI) { - e.tCCExt, _ = strconv.Atoi(att.Value[:1]) - break - } - } - for _, format := range md.MediaName.Formats { pt, err := strconv.Atoi(format) if err != nil { diff --git a/pkg/receiver.go b/pkg/receiver.go index b14c60c31..b6dfd5e30 100644 --- a/pkg/receiver.go +++ b/pkg/receiver.go @@ -17,20 +17,17 @@ const ( maxSize = 1024 ) -type ReceiverConfig struct { - RouterConfig - tccExt int -} - // Receiver defines a interface for a track receivers type Receiver interface { + Start() Track() *webrtc.Track AddSender(sender Sender) DeleteSender(pid string) SpatialLayer() uint8 - GetRTCP() (rtcp.ReceptionReport, []rtcp.Packet) OnCloseHandler(fn func()) - OnLostHandler(fn func(nack *rtcp.TransportLayerNack)) + OnTransportWideCC(fn func(sn uint16, timeNS int64, marker bool)) + SendRTCP(p []rtcp.Packet) + SetRTCPCh(ch chan []rtcp.Packet) WriteBufferedPacket(sn uint16, track *webrtc.Track, snOffset uint16, tsOffset, ssrc uint32) error Close() } @@ -45,6 +42,7 @@ type WebRTCReceiver struct { buffer *Buffer bandwidth uint64 rtpCh chan *rtp.Packet + rtcpCh chan []rtcp.Packet senders map[string]Sender onCloseHandler func() @@ -52,7 +50,7 @@ type WebRTCReceiver struct { } // NewWebRTCReceiver creates a new webrtc track receivers -func NewWebRTCReceiver(ctx context.Context, receiver *webrtc.RTPReceiver, track *webrtc.Track, config ReceiverConfig) Receiver { +func NewWebRTCReceiver(ctx context.Context, receiver *webrtc.RTPReceiver, track *webrtc.Track, config BufferOptions) Receiver { ctx, cancel := context.WithCancel(ctx) w := &WebRTCReceiver{ @@ -66,21 +64,30 @@ func NewWebRTCReceiver(ctx context.Context, receiver *webrtc.RTPReceiver, track switch w.track.RID() { case quarterResolution: - w.spatialLayer = 1 + w.spatialLayer = 0 case halfResolution: - w.spatialLayer = 2 + w.spatialLayer = 1 case fullResolution: - w.spatialLayer = 3 + w.spatialLayer = 2 default: w.spatialLayer = 0 } - w.buffer = NewBuffer(track, BufferOptions{ - BufferTime: config.MaxBufferTime, - MaxBitRate: config.MaxBandwidth * 1000, - TCCExt: config.tccExt, + w.buffer = NewBuffer(track, config) + + w.buffer.onFeedback(func(packets []rtcp.Packet) { + w.rtcpCh <- packets }) + w.buffer.onLostHandler(func(nack *rtcp.TransportLayerNack) { + log.Debugf("Writing nack to mediaSSRC: %d, missing sn: %d, bitmap: %b", track.SSRC(), nack.Nacks[0].PacketID, nack.Nacks[0].LostPackets) + w.rtcpCh <- []rtcp.Packet{nack} + }) + + return w +} + +func (w *WebRTCReceiver) Start() { go w.readRTP() if len(w.track.RID()) > 0 { go w.readSimulcastRTCP(w.track.RID()) @@ -88,8 +95,6 @@ func NewWebRTCReceiver(ctx context.Context, receiver *webrtc.RTPReceiver, track go w.readRTCP() } go w.writeRTP() - - return w } // OnCloseHandler method to be called on remote tracked removed @@ -97,8 +102,8 @@ func (w *WebRTCReceiver) OnCloseHandler(fn func()) { w.onCloseHandler = fn } -func (w *WebRTCReceiver) OnLostHandler(fn func(nack *rtcp.TransportLayerNack)) { - w.buffer.onLostHandler(fn) +func (w *WebRTCReceiver) OnTransportWideCC(fn func(sn uint16, timeNS int64, marker bool)) { + w.buffer.onTransportWideCC(fn) } func (w *WebRTCReceiver) AddSender(sender Sender) { @@ -114,6 +119,10 @@ func (w *WebRTCReceiver) DeleteSender(pid string) { delete(w.senders, pid) } +func (w *WebRTCReceiver) SendRTCP(p []rtcp.Packet) { + w.rtcpCh <- p +} + func (w *WebRTCReceiver) SpatialLayer() uint8 { return w.spatialLayer } @@ -123,10 +132,6 @@ func (w *WebRTCReceiver) Track() *webrtc.Track { return w.track } -func (w *WebRTCReceiver) GetRTCP() (rtcp.ReceptionReport, []rtcp.Packet) { - return w.buffer.getRTCP() -} - // WriteBufferedPacket writes buffered packet to track, return error if packet not found func (w *WebRTCReceiver) WriteBufferedPacket(sn uint16, track *webrtc.Track, snOffset uint16, tsOffset, ssrc uint32) error { if w.buffer == nil || w.ctx.Err() != nil { @@ -135,6 +140,10 @@ func (w *WebRTCReceiver) WriteBufferedPacket(sn uint16, track *webrtc.Track, snO return w.buffer.WritePacket(sn, track, snOffset, tsOffset, ssrc) } +func (w *WebRTCReceiver) SetRTCPCh(ch chan []rtcp.Packet) { + w.rtcpCh = ch +} + // Close gracefully close the track func (w *WebRTCReceiver) Close() { if w.ctx.Err() != nil { @@ -167,7 +176,7 @@ func (w *WebRTCReceiver) readRTP() { continue } - w.buffer.Push(pkt) + w.buffer.push(pkt) select { case <-w.ctx.Done(): @@ -181,7 +190,7 @@ func (w *WebRTCReceiver) readRTP() { func (w *WebRTCReceiver) readRTCP() { for { pkts, err := w.receiver.ReadRTCP() - if err == io.ErrClosedPipe || w.ctx.Err() != nil { + if err == io.ErrClosedPipe || err == io.EOF || w.ctx.Err() != nil { return } if err != nil { diff --git a/pkg/router.go b/pkg/router.go index 6a63d21d0..66c1d50ac 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -3,31 +3,32 @@ package sfu //go:generate go run github.com/matryer/moq -out router_mock_test.generated.go . Router import ( + "context" "math/rand" "sync" + "time" - log "github.com/pion/ion-log" + "github.com/pion/webrtc/v3" + log "github.com/pion/ion-log" "github.com/pion/rtcp" ) const ( - SimpleRouter = iota + 1 - SimulcastRouter - SVCRouter + SimpleReceiver = iota + 1 + SimulcastReceiver + SVCReceiver ) // Router defines a track rtp/rtcp router type Router interface { ID() string - Kind() int Config() RouterConfig - AddReceiver(recv Receiver) - GetReceiver(layer uint8) Receiver + AddReceiver(ctx context.Context, track *webrtc.Track, receiver *webrtc.RTPReceiver) AddSender(p *WebRTCTransport) error - GetRTCP() []rtcp.Packet - SendRTCP(pkts []rtcp.Packet) error - SwitchSpatialLayer(targetLayer uint8, sub Sender) bool + AddTWCCExt(id string, ext int) + SendRTCP(pkts []rtcp.Packet) + Stop() } // RouterConfig defines router configurations @@ -37,151 +38,174 @@ type RouterConfig struct { Simulcast SimulcastConfig `mapstructure:"simulcast"` } +type receiverRouter struct { + kind int + stream string + receivers [3]Receiver +} + type router struct { + id string mu sync.RWMutex - peer *WebRTCTransport - kind int + peer *webrtc.PeerConnection + twcc *TransportWideCC + rtcpCh chan []rtcp.Packet config RouterConfig - streamID string - receivers [3 + 1]Receiver + twccExts map[string]int + receivers map[string]*receiverRouter } // newRouter for routing rtp/rtcp packets -func newRouter(peer *WebRTCTransport, streamID string, config RouterConfig, kind int) Router { - return &router{ - peer: peer, - kind: kind, - config: config, - streamID: streamID, +func newRouter(peer *webrtc.PeerConnection, id string, config RouterConfig) Router { + ch := make(chan []rtcp.Packet, 10) + r := &router{ + id: id, + peer: peer, + twcc: newTransportWideCC(ch), + config: config, + rtcpCh: ch, + twccExts: make(map[string]int), + receivers: make(map[string]*receiverRouter), } + go r.sendRTCP() + return r } func (r *router) ID() string { - return r.peer.id -} - -func (r *router) Kind() int { - return r.kind + return r.id } func (r *router) Config() RouterConfig { return r.config } -func (r *router) AddReceiver(recv Receiver) { +func (r *router) AddReceiver(ctx context.Context, track *webrtc.Track, receiver *webrtc.RTPReceiver) { r.mu.Lock() defer r.mu.Unlock() - r.receivers[recv.SpatialLayer()] = recv -} -func (r *router) GetReceiver(layer uint8) Receiver { - r.mu.RLock() - defer r.mu.RUnlock() - return r.receivers[layer] -} - -// AddWebRTCSender to router -func (r *router) AddSender(p *WebRTCTransport) error { - r.mu.RLock() - defer r.mu.RUnlock() - var ( - recv Receiver - sender Sender - ssrc uint32 - ) - - if r.kind == SimpleRouter { - recv = r.receivers[0] - ssrc = recv.Track().SSRC() - } else { - for _, rcv := range r.receivers { - if rcv != nil { - recv = rcv - } - if !r.config.Simulcast.BestQualityFirst && rcv != nil { - break - } - } - ssrc = rand.Uint32() + trackID := track.ID() + recv := NewWebRTCReceiver(ctx, receiver, track, BufferOptions{ + BufferTime: r.config.MaxBufferTime, + MaxBitRate: r.config.MaxBandwidth * 1000, + TWCCExt: r.twccExts[trackID], + }) + recv.OnTransportWideCC(func(sn uint16, timeNS int64, marker bool) { + r.twcc.push(sn, timeNS, marker) + }) + recv.SetRTCPCh(r.rtcpCh) + recv.OnCloseHandler(func() { + r.deleteReceiver(trackID) + }) + if track.Kind() == webrtc.RTPCodecTypeVideo { + r.twcc.mSSRC = track.SSRC() + r.twcc.tccLastReport = time.Now().UnixNano() } + recv.Start() - if recv == nil { - return errNoReceiverFound + if rr, ok := r.receivers[trackID]; ok { + rr.receivers[recv.SpatialLayer()] = recv + return } - inTrack := recv.Track() - to := p.me.GetCodecsByName(recv.Track().Codec().Name) - if len(to) == 0 { - return errPtNotSupported - } - pt := to[0].PayloadType - outTrack, err := p.pc.NewTrack(pt, ssrc, inTrack.ID(), inTrack.Label()) - if err != nil { - return err + rr := &receiverRouter{ + stream: track.Label(), + receivers: [3]Receiver{}, } - // Create webrtc sender for the peer we are sending track to - s, err := p.pc.AddTrack(outTrack) - if err != nil { - return err - } - if r.kind == SimulcastRouter { - sender = NewSimulcastSender(p.ctx, p.id, r, s, recv.SpatialLayer()) + rr.receivers[recv.SpatialLayer()] = recv + + if len(track.RID()) > 0 { + rr.kind = SimulcastReceiver } else { - sender = NewSimpleSender(p.ctx, p.id, r, s) + rr.kind = SimpleReceiver } - sender.OnCloseHandler(func() { - if err := p.pc.RemoveTrack(s); err != nil { - log.Errorf("Error closing sender: %s", err) - } - }) - p.AddSender(r.streamID, sender) - recv.AddSender(sender) - return nil -} -func (r *router) SendRTCP(pkts []rtcp.Packet) error { - return r.peer.pc.WriteRTCP(pkts) + r.receivers[trackID] = rr } -func (r *router) GetRTCP() []rtcp.Packet { +// AddWebRTCSender to router +func (r *router) AddSender(p *WebRTCTransport) error { r.mu.RLock() defer r.mu.RUnlock() - if r.kind == SimpleRouter || r.kind == SVCRouter { - if r.receivers[0] != nil { - rr, ps := r.receivers[0].GetRTCP() - if rr.SSRC != 0 { - ps = append(ps, &rtcp.ReceiverReport{ - Reports: []rtcp.ReceptionReport{rr}, - }) + for _, rr := range r.receivers { + var ( + recv Receiver + sender Sender + ssrc uint32 + ) + + if rr.kind == SimpleReceiver { + recv = rr.receivers[0] + ssrc = recv.Track().SSRC() + } else { + for _, rcv := range rr.receivers { + if rcv != nil { + recv = rcv + } + if !r.config.Simulcast.BestQualityFirst && rcv != nil { + break + } } - return ps + ssrc = rand.Uint32() } - return nil - } - var rtcpPkts []rtcp.Packet - var rReports []rtcp.ReceptionReport - for _, recv := range r.receivers { - if recv != nil { - rr, ps := recv.GetRTCP() - rtcpPkts = append(rtcpPkts, ps...) - if rr.SSRC != 0 { - rReports = append(rReports, rr) - } + + if recv == nil { + return errNoReceiverFound } - } - if len(rReports) > 0 { - rtcpPkts = append(rtcpPkts, &rtcp.ReceiverReport{ - Reports: rReports, + + inTrack := recv.Track() + to := p.me.GetCodecsByName(recv.Track().Codec().Name) + if len(to) == 0 { + return errPtNotSupported + } + pt := to[0].PayloadType + outTrack, err := p.pc.NewTrack(pt, ssrc, inTrack.ID(), inTrack.Label()) + if err != nil { + return err + } + // Create webrtc sender for the peer we are sending track to + s, err := p.pc.AddTrack(outTrack) + if err != nil { + return err + } + if rr.kind == SimulcastReceiver { + sender = NewSimulcastSender(p.ctx, p.id, rr, s, recv.SpatialLayer(), r.config.Simulcast) + } else { + sender = NewSimpleSender(p.ctx, p.id, rr, s) + } + sender.OnCloseHandler(func() { + if err := p.pc.RemoveTrack(s); err != nil { + log.Errorf("Error closing sender: %s", err) + } }) + p.AddSender(rr.stream, sender) + recv.AddSender(sender) } - return rtcpPkts + return nil +} + +func (r *router) AddTWCCExt(id string, ext int) { + r.twccExts[id] = ext } -func (r *router) SwitchSpatialLayer(targetLayer uint8, sub Sender) bool { - if targetRecv := r.GetReceiver(targetLayer); targetRecv != nil { - targetRecv.AddSender(sub) - return true +func (r *router) SendRTCP(pkts []rtcp.Packet) { + r.rtcpCh <- pkts +} + +func (r *router) Stop() { + close(r.rtcpCh) +} + +func (r *router) deleteReceiver(track string) { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.receivers, track) +} + +func (r *router) sendRTCP() { + for pkts := range r.rtcpCh { + if err := r.peer.WriteRTCP(pkts); err != nil { + log.Errorf("Write rtcp to peer %s err :%v", r.id, err) + } } - return false } diff --git a/pkg/session.go b/pkg/session.go index 072a06c5b..08d60a1f5 100644 --- a/pkg/session.go +++ b/pkg/session.go @@ -54,7 +54,7 @@ func (r *Session) AddRouter(router Router) { continue } - log.Infof("AddRouter ssrc to %s", tid) + log.Infof("AddRouter mediaSSRC to %s", tid) if t, ok := t.(*WebRTCTransport); ok { if err := router.AddSender(t); err != nil { diff --git a/pkg/sfu.go b/pkg/sfu.go index 6f234e632..fa82d3672 100644 --- a/pkg/sfu.go +++ b/pkg/sfu.go @@ -53,8 +53,8 @@ func NewSFU(c Config) *SFU { // Configure required extensions sdes, _ := url.Parse(sdp.SDESRTPStreamIDURI) sdedMid, _ := url.Parse(sdp.SDESMidURI) - // transportCCURL, _ := url.Parse(sdp.TransportCCURI) - // rtcpfb = append(rtcpfb, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}) + transportCCURL, _ := url.Parse(sdp.TransportCCURI) + rtcpfb = append(rtcpfb, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}) rtcpfb = append(rtcpfb, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBGoogREMB}) se := webrtc.SettingEngine{} se.AddSDPExtensions(webrtc.SDPSectionVideo, @@ -65,9 +65,9 @@ func NewSFU(c Config) *SFU { { URI: sdedMid, }, - // { - // URI: transportCCURL, - // }, + { + URI: transportCCURL, + }, }) w := WebRTCTransportConfig{ diff --git a/pkg/simplesender.go b/pkg/simplesender.go index df11802be..cbca6e88a 100644 --- a/pkg/simplesender.go +++ b/pkg/simplesender.go @@ -21,7 +21,7 @@ type SimpleSender struct { cancel context.CancelFunc sender *webrtc.RTPSender track *webrtc.Track - router Router + router *receiverRouter enabled atomicBool payload uint8 maxBitrate uint64 @@ -40,7 +40,7 @@ type SimpleSender struct { } // NewSimpleSender creates a new track sender instance -func NewSimpleSender(ctx context.Context, id string, router Router, sender *webrtc.RTPSender) Sender { +func NewSimpleSender(ctx context.Context, id string, router *receiverRouter, sender *webrtc.RTPSender) Sender { ctx, cancel := context.WithCancel(ctx) s := &SimpleSender{ id: id, @@ -77,15 +77,14 @@ func (s *SimpleSender) WriteRTP(pkt *rtp.Packet) { if s.track.Kind() == webrtc.RTPCodecTypeVideo { // Forward pli to request a keyframe at max 1 pli per second if time.Now().Sub(s.lastPli) > time.Second { - recv := s.router.GetReceiver(0) + recv := s.router.receivers[0] if recv == nil { return } - if err := s.router.SendRTCP([]rtcp.Packet{ + recv.SendRTCP([]rtcp.Packet{ &rtcp.PictureLossIndication{SenderSSRC: pkt.SSRC, MediaSSRC: pkt.SSRC}, - }); err == nil { - s.lastPli = time.Now() - } + }) + s.lastPli = time.Now() } relay := false // Wait for a keyframe to sync new source @@ -190,7 +189,7 @@ func (s *SimpleSender) receiveRTCP() { pkts, err := s.sender.ReadRTCP() if err == io.ErrClosedPipe { // Remove sender from receiver - if recv := s.router.GetReceiver(0); recv != nil { + if recv := s.router.receivers[0]; recv != nil { recv.DeleteSender(s.id) } s.Close() @@ -205,7 +204,7 @@ func (s *SimpleSender) receiveRTCP() { log.Errorf("rtcp err => %v", err) } - recv := s.router.GetReceiver(0) + recv := s.router.receivers[0] if recv == nil { continue } @@ -234,9 +233,7 @@ func (s *SimpleSender) receiveRTCP() { } } if len(fwdPkts) > 0 { - if err := s.router.SendRTCP(fwdPkts); err != nil { - log.Errorf("Forwarding rtcp from sender err: %v", err) - } + recv.SendRTCP(fwdPkts) } } } diff --git a/pkg/simplesender_test.go b/pkg/simplesender_test.go index f48368b0f..1ebfb116d 100644 --- a/pkg/simplesender_test.go +++ b/pkg/simplesender_test.go @@ -26,7 +26,7 @@ func TestNewSimpleSender(t *testing.T) { type args struct { ctx context.Context id string - router Router + router *receiverRouter sender *webrtc.RTPSender } tests := []struct { @@ -157,17 +157,10 @@ func TestSimpleSender_receiveRTCP(t *testing.T) { fakeReceiver := &ReceiverMock{ DeleteSenderFunc: func(_ string) { }, - } - - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return fakeReceiver - }, - SendRTCPFunc: func(pkts []rtcp.Packet) error { - for _, pkt := range pkts { - gotRTCP <- pkt + SendRTCPFunc: func(p []rtcp.Packet) { + for _, pp := range p { + gotRTCP <- pp } - return nil }, } @@ -203,7 +196,6 @@ forLoop: MediaSSRC: 1234, }, }, - // TODO: Add test cases. } for _, tt := range tests { tt := tt @@ -212,7 +204,11 @@ forLoop: wss := &SimpleSender{ ctx: ctx, cancel: cancel, - router: fakeRouter, + router: &receiverRouter{ + kind: SimpleReceiver, + stream: "123", + receivers: [3]Receiver{fakeReceiver}, + }, sender: s, track: senderTrack, } @@ -250,16 +246,11 @@ forLoop: func TestSimpleSender_Close(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) closeCtr := 0 - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return nil - }, - } type fields struct { ctx context.Context cancel context.CancelFunc - router Router + router *receiverRouter onCloseHandler func() } tests := []struct { @@ -272,7 +263,7 @@ func TestSimpleSender_Close(t *testing.T) { fields: fields{ ctx: ctx, cancel: cancel, - router: fakeRouter, + router: nil, onCloseHandler: nil, }, }, @@ -282,7 +273,7 @@ func TestSimpleSender_Close(t *testing.T) { fields: fields{ ctx: ctx, cancel: cancel, - router: fakeRouter, + router: nil, onCloseHandler: func() { closeCtr++ }, @@ -458,26 +449,25 @@ forLoop: } gotPli := make(chan struct{}, 1) - fakeRecv := &ReceiverMock{} - - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return fakeRecv - }, - SendRTCPFunc: func(pkts []rtcp.Packet) error { - for _, pkt := range pkts { - if _, ok := pkt.(*rtcp.PictureLossIndication); ok { + fakeRecv := &ReceiverMock{ + SendRTCPFunc: func(p []rtcp.Packet) { + for _, pp := range p { + if _, ok := pp.(*rtcp.PictureLossIndication); ok { gotPli <- struct{}{} } } - return nil }, } + r := &receiverRouter{ + kind: SimpleReceiver, + receivers: [3]Receiver{fakeRecv}, + } + simpleSdr := SimpleSender{ ctx: context.Background(), enabled: atomicBool{1}, - router: fakeRouter, + router: r, track: senderTrack, payload: senderTrack.PayloadType(), } diff --git a/pkg/simulcastsender.go b/pkg/simulcastsender.go index d2a60f66b..3738501bc 100644 --- a/pkg/simulcastsender.go +++ b/pkg/simulcastsender.go @@ -20,7 +20,7 @@ type SimulcastSender struct { id string ctx context.Context cancel context.CancelFunc - router Router + router *receiverRouter sender *webrtc.RTPSender track *webrtc.Track enabled atomicBool @@ -55,7 +55,7 @@ type SimulcastSender struct { } // NewSimulcastSender creates a new track sender instance -func NewSimulcastSender(ctx context.Context, id string, router Router, sender *webrtc.RTPSender, layer uint8) Sender { +func NewSimulcastSender(ctx context.Context, id string, router *receiverRouter, sender *webrtc.RTPSender, layer uint8, conf SimulcastConfig) Sender { ctx, cancel := context.WithCancel(ctx) s := &SimulcastSender{ id: id, @@ -70,7 +70,7 @@ func NewSimulcastSender(ctx context.Context, id string, router Router, sender *w currentSpatialLayer: layer, targetSpatialLayer: layer, simulcastSSRC: sender.Track().SSRC(), - temporalEnabled: router.Config().Simulcast.EnableTemporalLayer, + temporalEnabled: conf.EnableTemporalLayer, refPicID: uint16(rand.Uint32()), refTlzi: uint8(rand.Uint32()), } @@ -98,17 +98,17 @@ func (s *SimulcastSender) WriteRTP(pkt *rtp.Packet) { // Check if packet SSRC is different from before // if true, the video source changed if s.lSSRC != pkt.SSRC { - recv := s.router.GetReceiver(s.targetSpatialLayer) + recv := s.router.receivers[s.targetSpatialLayer] if recv == nil || recv.Track().SSRC() != pkt.SSRC { return } // Forward pli to request a keyframe at max 1 pli per second if time.Now().Sub(s.lastPli) > time.Second { - if err := s.router.SendRTCP([]rtcp.Packet{ + recv.SendRTCP([]rtcp.Packet{ &rtcp.PictureLossIndication{SenderSSRC: pkt.SSRC, MediaSSRC: pkt.SSRC}, - }); err == nil { - s.lastPli = time.Now() - } + }) + s.lastPli = time.Now() + } relay := false // Wait for a keyframe to sync new source @@ -143,7 +143,7 @@ func (s *SimulcastSender) WriteRTP(pkt *rtp.Packet) { } // Switch is done remove sender from previous layer // and update current layer - if pRecv := s.router.GetReceiver(s.currentSpatialLayer); pRecv != nil && s.currentSpatialLayer != s.targetSpatialLayer { + if pRecv := s.router.receivers[s.currentSpatialLayer]; pRecv != nil && s.currentSpatialLayer != s.targetSpatialLayer { pRecv.DeleteSender(s.id) } s.currentSpatialLayer = s.targetSpatialLayer @@ -211,7 +211,8 @@ func (s *SimulcastSender) SwitchSpatialLayer(targetLayer uint8) { if s.currentSpatialLayer != s.targetSpatialLayer { return } - if ok := s.router.SwitchSpatialLayer(targetLayer, s); ok { + if recv := s.router.receivers[targetLayer]; recv != nil { + recv.AddSender(s) s.targetSpatialLayer = targetLayer } } @@ -234,7 +235,7 @@ func (s *SimulcastSender) Mute(val bool) { } s.enabled.set(!val) if !val { - // reset last ssrc to force a re-sync + // reset last mediaSSRC to force a re-sync s.lSSRC = 0 } } @@ -267,7 +268,7 @@ func (s *SimulcastSender) receiveRTCP() { pkts, err := s.sender.ReadRTCP() if err == io.ErrClosedPipe { // Remove sender from receiver - if recv := s.router.GetReceiver(s.currentSpatialLayer); recv != nil { + if recv := s.router.receivers[s.currentSpatialLayer]; recv != nil { recv.DeleteSender(s.id) } s.Close() @@ -283,7 +284,7 @@ func (s *SimulcastSender) receiveRTCP() { continue } - recv := s.router.GetReceiver(s.currentSpatialLayer) + recv := s.router.receivers[s.currentSpatialLayer] if recv == nil { continue } @@ -318,9 +319,7 @@ func (s *SimulcastSender) receiveRTCP() { } } if len(fwdPkts) > 0 { - if err := s.router.SendRTCP(fwdPkts); err != nil { - log.Errorf("Forwarding rtcp from sender err: %v", err) - } + recv.SendRTCP(fwdPkts) } } } diff --git a/pkg/simulcastsender_test.go b/pkg/simulcastsender_test.go index fe5d0a393..a8e7572d4 100644 --- a/pkg/simulcastsender_test.go +++ b/pkg/simulcastsender_test.go @@ -28,7 +28,7 @@ func TestNewWebRTCSimulcastSender(t *testing.T) { type args struct { ctx context.Context id string - router Router + router *receiverRouter sender *webrtc.RTPSender layer uint8 } @@ -39,13 +39,9 @@ func TestNewWebRTCSimulcastSender(t *testing.T) { { name: "Must return a non nil Sender", args: args{ - ctx: ctx, - id: "test", - router: &RouterMock{ - ConfigFunc: func() RouterConfig { - return RouterConfig{} - }, - }, + ctx: ctx, + id: "test", + router: nil, sender: sender, layer: 2, }, @@ -54,7 +50,7 @@ func TestNewWebRTCSimulcastSender(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - got := NewSimulcastSender(tt.args.ctx, tt.args.id, tt.args.router, tt.args.sender, tt.args.layer) + got := NewSimulcastSender(tt.args.ctx, tt.args.id, tt.args.router, tt.args.sender, tt.args.layer, SimulcastConfig{}) assert.NotNil(t, got) }) } @@ -94,22 +90,20 @@ func TestSimulcastSender_WriteRTP(t *testing.T) { TrackFunc: func() *webrtc.Track { return fakeRecvTrack }, - } - - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return fakeReceiver - }, - SendRTCPFunc: func(pkts []rtcp.Packet) error { - for _, pkt := range pkts { - if _, ok := pkt.(*rtcp.PictureLossIndication); ok { + SendRTCPFunc: func(p []rtcp.Packet) { + for _, pp := range p { + if _, ok := pp.(*rtcp.PictureLossIndication); ok { gotPli <- struct{}{} } } - return nil }, } + r := &receiverRouter{ + kind: SimulcastReceiver, + receivers: [3]Receiver{fakeReceiver, fakeReceiver, fakeReceiver}, + } + err = signalPair(sfu, remote) assert.NoError(t, err) @@ -161,7 +155,7 @@ forLoop: s := &SimulcastSender{ ctx: context.Background(), enabled: atomicBool{1}, - router: fakeRouter, + router: r, track: senderTrack, simulcastSSRC: simulcastSSRC, } @@ -182,7 +176,7 @@ forLoop: s := &SimulcastSender{ ctx: context.Background(), enabled: atomicBool{1}, - router: fakeRouter, + router: r, track: senderTrack, simulcastSSRC: simulcastSSRC, lSSRC: fakeRecvTrack.SSRC(), @@ -227,20 +221,18 @@ func TestSimulcastSender_receiveRTCP(t *testing.T) { fakeReceiver := &ReceiverMock{ DeleteSenderFunc: func(_ string) { }, - } - - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return fakeReceiver - }, - SendRTCPFunc: func(pkts []rtcp.Packet) error { - for _, pkt := range pkts { - gotRTCP <- pkt + SendRTCPFunc: func(p []rtcp.Packet) { + for _, pp := range p { + gotRTCP <- pp } - return nil }, } + r := &receiverRouter{ + kind: SimulcastReceiver, + receivers: [3]Receiver{fakeReceiver, fakeReceiver, fakeReceiver}, + } + err = signalPair(sfu, remote) assert.NoError(t, err) @@ -283,7 +275,7 @@ forLoop: ctx: ctx, cancel: cancel, enabled: atomicBool{1}, - router: fakeRouter, + router: r, sender: s, track: senderTrack, lSSRC: recvSSRC, @@ -325,16 +317,16 @@ forLoop: func TestSimulcastSender_Close(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) closeCtr := 0 - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return nil - }, + + r := &receiverRouter{ + kind: SimulcastReceiver, + receivers: [3]Receiver{}, } type fields struct { ctx context.Context cancel context.CancelFunc - router Router + router *receiverRouter onCloseHandler func() } tests := []struct { @@ -347,7 +339,7 @@ func TestSimulcastSender_Close(t *testing.T) { fields: fields{ ctx: ctx, cancel: cancel, - router: fakeRouter, + router: r, onCloseHandler: nil, }, }, @@ -357,7 +349,7 @@ func TestSimulcastSender_Close(t *testing.T) { fields: fields{ ctx: ctx, cancel: cancel, - router: fakeRouter, + router: r, onCloseHandler: func() { closeCtr++ }, @@ -541,31 +533,29 @@ forLoop: } gotPli := make(chan struct{}, 1) - fakeRecv := &ReceiverMock{ + fakeReceiver := &ReceiverMock{ TrackFunc: func() *webrtc.Track { return senderTrack }, - } - - fakeRouter := &RouterMock{ - GetReceiverFunc: func(_ uint8) Receiver { - return fakeRecv - }, - SendRTCPFunc: func(pkts []rtcp.Packet) error { - for _, pkt := range pkts { - if _, ok := pkt.(*rtcp.PictureLossIndication); ok { + SendRTCPFunc: func(p []rtcp.Packet) { + for _, pp := range p { + if _, ok := pp.(*rtcp.PictureLossIndication); ok { gotPli <- struct{}{} } } - return nil }, } + r := &receiverRouter{ + kind: SimulcastReceiver, + receivers: [3]Receiver{fakeReceiver, fakeReceiver, fakeReceiver}, + } + simpleSdr := SimulcastSender{ ctx: context.Background(), enabled: atomicBool{1}, simulcastSSRC: 1234, - router: fakeRouter, + router: r, track: senderTrack, payload: senderTrack.PayloadType(), lSSRC: 1234, diff --git a/pkg/transport.go b/pkg/transport.go index d408fc489..2557d1203 100644 --- a/pkg/transport.go +++ b/pkg/transport.go @@ -4,8 +4,7 @@ package sfu // that media can be sent over type Transport interface { ID() string - GetRouter(string) Router - Routers() map[string]Router + GetRouter() Router AddSender(streamID string, sender Sender) GetSenders(streamID string) []Sender } diff --git a/pkg/twcc.go b/pkg/twcc.go new file mode 100644 index 000000000..a57401f8a --- /dev/null +++ b/pkg/twcc.go @@ -0,0 +1,288 @@ +package sfu + +import ( + "encoding/binary" + "math" + "math/rand" + "sort" + "sync" + + "github.com/gammazero/deque" + "github.com/pion/rtcp" +) + +const ( + baseSequenceNumberOffset = 8 + packetStatusCountOffset = 10 + referenceTimeOffset = 12 + + tccReportDelta = 1e8 + tccReportDeltaAfterMark = 50e6 +) + +type rtpExtInfo struct { + ExtTSN uint32 + Timestamp int64 +} + +type TransportWideCC struct { + sync.Mutex + rtcpCh chan []rtcp.Packet + + tccExtInfo []rtpExtInfo + tccLastReport int64 + tccCycles uint32 + tccLastExtSN uint32 + tccPktCtn uint8 + tccLastSn uint16 + lastExtInfo uint16 + mSSRC uint32 + sSSRC uint32 + + len uint16 + deltaLen uint16 + payload [50]byte + deltas [100]byte + chunk uint16 +} + +func newTransportWideCC(ch chan []rtcp.Packet) *TransportWideCC { + return &TransportWideCC{ + tccExtInfo: make([]rtpExtInfo, 0, 101), + rtcpCh: ch, + sSSRC: rand.Uint32(), + } +} + +func (t *TransportWideCC) push(sn uint16, timeNS int64, marker bool) { + t.Lock() + defer t.Unlock() + + if sn < 0x0fff && (t.tccLastSn&0xffff) > 0xf000 { + t.tccCycles += maxSN + } + t.tccExtInfo = append(t.tccExtInfo, rtpExtInfo{ + ExtTSN: t.tccCycles | uint32(sn), + Timestamp: timeNS / 1e3, + }) + t.tccLastSn = sn + delta := timeNS - t.tccLastReport + if delta >= tccReportDelta || len(t.tccExtInfo) > 100 || (marker && delta >= tccReportDeltaAfterMark) { + if pkt := t.buildTransportCCPacket(); pkt != nil { + t.rtcpCh <- []rtcp.Packet{pkt} + } + t.tccLastReport = timeNS + } +} + +func (t *TransportWideCC) buildTransportCCPacket() *rtcp.RawPacket { + if len(t.tccExtInfo) == 0 { + return nil + } + sort.Slice(t.tccExtInfo, func(i, j int) bool { + return t.tccExtInfo[i].ExtTSN < t.tccExtInfo[j].ExtTSN + }) + tccPkts := make([]rtpExtInfo, 0, int(float64(len(t.tccExtInfo))*1.2)) + for _, tccExtInfo := range t.tccExtInfo { + if tccExtInfo.ExtTSN < t.tccLastExtSN { + continue + } + if t.tccLastExtSN != 0 { + for j := t.tccLastExtSN + 1; j < tccExtInfo.ExtTSN; j++ { + tccPkts = append(tccPkts, rtpExtInfo{ExtTSN: j}) + } + } + t.tccLastExtSN = tccExtInfo.ExtTSN + tccPkts = append(tccPkts, tccExtInfo) + } + t.tccExtInfo = t.tccExtInfo[:0] + + firstRecv := false + same := true + timestamp := int64(0) + lastStatus := rtcp.TypeTCCPacketReceivedWithoutDelta + maxStatus := rtcp.TypeTCCPacketNotReceived + + var statusList deque.Deque + statusList.SetMinCapacity(3) + + for _, stat := range tccPkts { + status := rtcp.TypeTCCPacketNotReceived + if stat.Timestamp != 0 { + var delta int64 + if !firstRecv { + firstRecv = true + refTime := stat.Timestamp / 64e3 + timestamp = refTime * 64e3 + t.writeHeader( + uint16(tccPkts[0].ExtTSN), + uint16(len(tccPkts)), + uint32(refTime), + ) + t.tccPktCtn++ + } + + delta = (stat.Timestamp - timestamp) / 250 + if delta < 0 || delta > 255 { + status = rtcp.TypeTCCPacketReceivedLargeDelta + rDelta := int16(delta) + if int64(rDelta) != delta { + if rDelta > 0 { + rDelta = math.MaxInt16 + } else { + rDelta = math.MinInt16 + } + } + t.writeDelta(status, uint16(rDelta)) + } else { + status = rtcp.TypeTCCPacketReceivedSmallDelta + t.writeDelta(status, uint16(delta)) + } + timestamp = stat.Timestamp + } + + if same && status != lastStatus && lastStatus != rtcp.TypeTCCPacketReceivedWithoutDelta { + if statusList.Len() > 7 { + t.writeRunLengthChunk(lastStatus, uint16(statusList.Len())) + statusList.Clear() + lastStatus = rtcp.TypeTCCPacketReceivedWithoutDelta + maxStatus = rtcp.TypeTCCPacketNotReceived + same = true + } else { + same = false + } + } + statusList.PushBack(status) + if status > maxStatus { + maxStatus = status + } + lastStatus = status + + if !same && maxStatus == rtcp.TypeTCCPacketReceivedLargeDelta && statusList.Len() > 6 { + for i := 0; i < 7; i++ { + t.createStatusSymbolChunk(rtcp.TypeTCCSymbolSizeTwoBit, statusList.PopFront().(uint16), i) + } + t.writeStatusSymbolChunk(rtcp.TypeTCCSymbolSizeTwoBit) + lastStatus = rtcp.TypeTCCPacketReceivedWithoutDelta + maxStatus = rtcp.TypeTCCPacketNotReceived + same = true + + for i := 0; i < statusList.Len(); i++ { + status = statusList.At(i).(uint16) + if status > maxStatus { + maxStatus = status + } + if same && lastStatus != rtcp.TypeTCCPacketReceivedWithoutDelta && status != lastStatus { + same = false + } + lastStatus = status + } + } else if !same && statusList.Len() > 13 { + for i := 0; i < 14; i++ { + t.createStatusSymbolChunk(rtcp.TypeTCCSymbolSizeOneBit, statusList.PopFront().(uint16), i) + } + t.writeStatusSymbolChunk(rtcp.TypeTCCSymbolSizeOneBit) + lastStatus = rtcp.TypeTCCPacketReceivedWithoutDelta + maxStatus = rtcp.TypeTCCPacketNotReceived + same = true + } + } + + if statusList.Len() > 0 { + if same { + t.writeRunLengthChunk(lastStatus, uint16(statusList.Len())) + } else if maxStatus == rtcp.TypeTCCPacketReceivedLargeDelta { + for i := 0; i < statusList.Len(); i++ { + t.createStatusSymbolChunk(rtcp.TypeTCCSymbolSizeTwoBit, statusList.PopFront().(uint16), i) + } + t.writeStatusSymbolChunk(rtcp.TypeTCCSymbolSizeTwoBit) + } else { + for i := 0; i < statusList.Len(); i++ { + t.createStatusSymbolChunk(rtcp.TypeTCCSymbolSizeOneBit, statusList.PopFront().(uint16), i) + } + t.writeStatusSymbolChunk(rtcp.TypeTCCSymbolSizeOneBit) + } + } + + pLen := t.len + t.deltaLen + 4 + pad := pLen%4 != 0 + var padSize uint8 + for pLen%4 != 0 { + padSize++ + pLen++ + } + hdr := rtcp.Header{ + Padding: pad, + Length: (pLen / 4) - 1, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + } + hb, _ := hdr.Marshal() + pkt := make(rtcp.RawPacket, pLen) + copy(pkt, hb) + copy(pkt[4:], t.payload[:t.len]) + copy(pkt[4+t.len:], t.deltas[:t.deltaLen]) + if pad { + pkt[len(pkt)-1] = padSize + } + t.deltaLen = 0 + return &pkt +} + +func (t *TransportWideCC) writeHeader(bSN, packetCount uint16, refTime uint32) { + /* + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC of packet sender | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC of media source | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | base sequence number | packet status count | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | reference time | fb pkt. count | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + binary.BigEndian.PutUint32(t.payload[0:], t.sSSRC) + binary.BigEndian.PutUint32(t.payload[4:], t.mSSRC) + binary.BigEndian.PutUint16(t.payload[baseSequenceNumberOffset:], bSN) + binary.BigEndian.PutUint16(t.payload[packetStatusCountOffset:], packetCount) + binary.BigEndian.PutUint32(t.payload[referenceTimeOffset:], refTime<<8|uint32(t.tccPktCtn)) + t.len = 16 +} + +func (t *TransportWideCC) writeRunLengthChunk(symbol uint16, runLength uint16) { + /* + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |T| S | Run Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + binary.BigEndian.PutUint16(t.payload[t.len:], symbol<<13|runLength) + t.len += 2 +} + +func (t *TransportWideCC) createStatusSymbolChunk(symbolSize, symbol uint16, i int) { + /* + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |T|S| symbol list | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + numOfBits := symbolSize + 1 + t.chunk = setNBitsOfUint16(t.chunk, numOfBits, numOfBits*uint16(i)+2, symbol) +} + +func (t *TransportWideCC) writeStatusSymbolChunk(symbolSize uint16) { + t.chunk = setNBitsOfUint16(t.chunk, 1, 0, 1) + t.chunk = setNBitsOfUint16(t.chunk, 1, 1, symbolSize) + binary.BigEndian.PutUint16(t.payload[t.len:], t.chunk) + t.chunk = 0 + t.len += 2 +} + +func (t *TransportWideCC) writeDelta(deltaType, delta uint16) { + if deltaType == rtcp.TypeTCCPacketReceivedSmallDelta { + t.deltas[t.deltaLen] = byte(delta) + t.deltaLen++ + return + } + binary.BigEndian.PutUint16(t.deltas[t.deltaLen:], delta) + t.deltaLen += 2 +} diff --git a/pkg/twcc_test.go b/pkg/twcc_test.go new file mode 100644 index 000000000..b19a4885b --- /dev/null +++ b/pkg/twcc_test.go @@ -0,0 +1,326 @@ +package sfu + +import ( + "testing" + + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestTransportWideCC_writeRunLengthChunk(t1 *testing.T) { + type fields struct { + len uint16 + } + type args struct { + symbol uint16 + runLength uint16 + } + tests := []struct { + name string + fields fields + args args + wantErr bool + wantBytes []byte + }{ + { + name: "Must not return error", + + args: args{ + symbol: rtcp.TypeTCCPacketNotReceived, + runLength: 221, + }, + wantErr: false, + wantBytes: []byte{0, 0xdd}, + }, { + name: "Must set run length after padding", + fields: fields{ + len: 1, + }, + args: args{ + symbol: rtcp.TypeTCCPacketReceivedWithoutDelta, + runLength: 24, + }, + wantBytes: []byte{0, 0x60, 0x18}, + }, + } + for _, tt := range tests { + tt := tt + t1.Run(tt.name, func(t1 *testing.T) { + t := &TransportWideCC{ + len: tt.fields.len, + } + t.writeRunLengthChunk(tt.args.symbol, tt.args.runLength) + assert.Equal(t1, tt.wantBytes, t.payload[:t.len]) + }) + } +} + +func TestTransportWideCC_writeStatusSymbolChunk(t1 *testing.T) { + type fields struct { + len uint16 + } + type args struct { + symbolSize uint16 + symbolList []uint16 + } + tests := []struct { + name string + fields fields + args args + wantBytes []byte + }{ + { + name: "Must not return error", + args: args{ + symbolSize: rtcp.TypeTCCSymbolSizeOneBit, + symbolList: []uint16{rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived}, + }, + wantBytes: []byte{0x9F, 0x1C}, + }, + { + name: "Must set symbol chunk after padding", + fields: fields{ + len: 1, + }, + args: args{ + symbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + symbolList: []uint16{ + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedWithoutDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived}, + }, + wantBytes: []byte{0x0, 0xcd, 0x50}, + }, + } + for _, tt := range tests { + tt := tt + t1.Run(tt.name, func(t1 *testing.T) { + t := &TransportWideCC{ + len: tt.fields.len, + } + for i, v := range tt.args.symbolList { + t.createStatusSymbolChunk(tt.args.symbolSize, v, i) + } + t.writeStatusSymbolChunk(tt.args.symbolSize) + assert.Equal(t1, tt.wantBytes, t.payload[:t.len]) + }) + } +} + +func TestTransportWideCC_writeDelta(t1 *testing.T) { + a := -32768 + type fields struct { + deltaLen uint16 + } + type args struct { + deltaType uint16 + delta uint16 + } + tests := []struct { + name string + fields fields + args args + want []byte + }{ + { + name: "Must set correct small delta", + args: args{ + deltaType: rtcp.TypeTCCPacketReceivedSmallDelta, + delta: 255, + }, + want: []byte{0xff}, + }, + { + name: "Must set correct small delta with padding", + fields: fields{ + deltaLen: 1, + }, + args: args{ + deltaType: rtcp.TypeTCCPacketReceivedSmallDelta, + delta: 255, + }, + want: []byte{0, 0xff}, + }, + { + name: "Must set correct large delta", + args: args{ + deltaType: rtcp.TypeTCCPacketReceivedLargeDelta, + delta: 32767, + }, + want: []byte{0x7F, 0xFF}, + }, + { + name: "Must set correct large delta with padding", + fields: fields{ + deltaLen: 1, + }, + args: args{ + deltaType: rtcp.TypeTCCPacketReceivedLargeDelta, + delta: uint16(a), + }, + want: []byte{0, 0x80, 0x00}, + }, + } + for _, tt := range tests { + tt := tt + t1.Run(tt.name, func(t1 *testing.T) { + t := &TransportWideCC{ + deltaLen: tt.fields.deltaLen, + } + t.writeDelta(tt.args.deltaType, tt.args.delta) + assert.Equal(t1, tt.want, t.deltas[:t.deltaLen]) + assert.Equal(t1, tt.fields.deltaLen+tt.args.deltaType, t.deltaLen) + }) + } +} + +func TestTransportWideCC_writeHeader(t1 *testing.T) { + type fields struct { + tccPktCtn uint8 + sSSRC uint32 + mSSRC uint32 + } + type args struct { + bSN uint16 + packetCount uint16 + refTime uint32 + } + tests := []struct { + name string + fields fields + args args + want []byte + }{ + { + name: "Must construct correct header", + fields: fields{ + tccPktCtn: 23, + sSSRC: 4195875351, + mSSRC: 1124282272, + }, + args: args{ + bSN: 153, + packetCount: 1, + refTime: 4057090, + }, + want: []byte{ + 0xfa, 0x17, 0xfa, 0x17, + 0x43, 0x3, 0x2f, 0xa0, + 0x0, 0x99, 0x0, 0x1, + 0x3d, 0xe8, 0x2, 0x17}, + }, + } + for _, tt := range tests { + tt := tt + t1.Run(tt.name, func(t1 *testing.T) { + t := &TransportWideCC{ + tccPktCtn: tt.fields.tccPktCtn, + sSSRC: tt.fields.sSSRC, + mSSRC: tt.fields.mSSRC, + } + t.writeHeader(tt.args.bSN, tt.args.packetCount, tt.args.refTime) + assert.Equal(t1, tt.want, t.payload[0:16]) + }) + } +} + +func TestTccPacket(t1 *testing.T) { + want := []byte{ + 0xfa, 0x17, 0xfa, 0x17, + 0x43, 0x3, 0x2f, 0xa0, + 0x0, 0x99, 0x0, 0x1, + 0x3d, 0xe8, 0x2, 0x17, + 0x60, 0x18, 0x0, 0xdd, + 0x9F, 0x1C, 0xcd, 0x50, + } + + delta := []byte{ + 0xff, 0x80, 0xaa, + } + + symbol1 := []uint16{rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived} + symbol2 := []uint16{ + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedWithoutDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived} + + t := &TransportWideCC{ + tccPktCtn: 23, + sSSRC: 4195875351, + mSSRC: 1124282272, + } + t.writeHeader(153, 1, 4057090) + t.writeRunLengthChunk(rtcp.TypeTCCPacketReceivedWithoutDelta, 24) + t.writeRunLengthChunk(rtcp.TypeTCCPacketNotReceived, 221) + for i, v := range symbol1 { + t.createStatusSymbolChunk(rtcp.TypeTCCSymbolSizeOneBit, v, i) + } + t.writeStatusSymbolChunk(rtcp.TypeTCCSymbolSizeOneBit) + for i, v := range symbol2 { + t.createStatusSymbolChunk(rtcp.TypeTCCSymbolSizeTwoBit, v, i) + } + t.writeStatusSymbolChunk(rtcp.TypeTCCSymbolSizeTwoBit) + t.deltaLen = uint16(len(delta)) + assert.Equal(t1, want, t.payload[:24]) + + pLen := t.len + t.deltaLen + 4 + pad := pLen%4 != 0 + for pLen%4 != 0 { + pLen++ + } + hdr := rtcp.Header{ + Padding: pad, + Length: (pLen / 4) - 1, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + } + assert.Equal(t1, int(pLen), len(want)+3+4+1) + hb, _ := hdr.Marshal() + pkt := make([]byte, pLen) + copy(pkt, hb) + assert.Equal(t1, hb, pkt[:len(hb)]) + copy(pkt[4:], t.payload[:t.len]) + assert.Equal(t1, append(hb, t.payload[:t.len]...), pkt[:len(hb)+int(t.len)]) + copy(pkt[4+t.len:], delta[:t.deltaLen]) + assert.Equal(t1, delta, pkt[len(hb)+int(t.len):len(pkt)-1]) + var ss rtcp.TransportLayerCC + err := ss.Unmarshal(pkt) + assert.NoError(t1, err) + + assert.Equal(t1, hdr, ss.Header) + +} diff --git a/pkg/webrtctransport.go b/pkg/webrtctransport.go index 5bd1c60ee..c2ea8b874 100644 --- a/pkg/webrtctransport.go +++ b/pkg/webrtctransport.go @@ -2,14 +2,13 @@ package sfu import ( "context" + "strconv" "strings" "sync" - "time" "github.com/bep/debounce" "github.com/lucsky/cuid" log "github.com/pion/ion-log" - "github.com/pion/rtcp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" ) @@ -30,10 +29,10 @@ type WebRTCTransport struct { me MediaEngine mu sync.RWMutex candidates []webrtc.ICECandidateInit - mids map[string]Sender session *Session + mids map[string]Sender senders map[string][]Sender - routers map[string]Router + router Router onTrackHandler func(*webrtc.Track, *webrtc.RTPReceiver) } @@ -48,116 +47,44 @@ func NewWebRTCTransport(ctx context.Context, session *Session, me MediaEngine, c } ctx, cancel := context.WithCancel(ctx) + id := cuid.New() p := &WebRTCTransport{ - id: cuid.New(), + id: id, ctx: ctx, cancel: cancel, pc: pc, me: me, session: session, - routers: make(map[string]Router), + router: newRouter(pc, id, cfg.router), mids: make(map[string]Sender), senders: make(map[string][]Sender), } - - // Add transport to the session - session.AddTransport(p) - // Simulcast flag to add router to session - simulcastToSessionJoined := false - // Subscribe to existing transports defer func() { for _, t := range session.Transports() { if t.ID() == p.id { continue } - - for _, router := range t.Routers() { - err := router.AddSender(p) - // log.Infof("Init add router ssrc %d to %s", router.receivers[0].Track().SSRC(), p.id) - if err != nil { - log.Errorf("Subscribing to router err: %v", err) - continue - } + err := t.GetRouter().AddSender(p) + if err != nil { + log.Errorf("Subscribing to router err: %v", err) + continue } } }() + // Add transport to the session + session.AddTransport(p) pc.OnTrack(func(track *webrtc.Track, receiver *webrtc.RTPReceiver) { - log.Debugf("Peer %s got remote track id: %s ssrc: %d rid :%s streamID: %s", p.id, track.ID(), track.SSRC(), track.RID(), track.Label()) - recv := NewWebRTCReceiver(ctx, receiver, track, ReceiverConfig{ - RouterConfig: cfg.router, - tccExt: me.tCCExt, - }) - - if router, ok := p.routers[track.ID()]; !ok { - if track.RID() != "" { - router = newRouter(p, track.Label(), cfg.router, SimulcastRouter) - } else { - router = newRouter(p, track.Label(), cfg.router, SimpleRouter) - } - router.AddReceiver(recv) - // If track is simulcast and BestQualityFirst is true and current track is full resolution subscribe to router - if router.Kind() == SimulcastRouter && router.Config().Simulcast.BestQualityFirst && track.RID() == fullResolution { - simulcastToSessionJoined = true - p.session.AddRouter(router) - // If track is simulcast AND BestQualityFirst is false and track is full resolution - } else if router.Kind() == SimulcastRouter && !router.Config().Simulcast.BestQualityFirst && track.RID() == fullResolution { - // Wait one second to receive the quarter resolution, if not received it may be not supported or disabled - // and only half or full resolution was sent. - go func() { - select { - case <-time.After(time.Second): - if !simulcastToSessionJoined { - simulcastToSessionJoined = true - p.session.AddRouter(router) - return - } - } - }() - // If track is not simulcast OR is simulcast and BestQualityFirst is false and current track is not full - // resolution subscribe to router - } else if router.Kind() != SimulcastRouter || router.Kind() == SimulcastRouter && - !router.Config().Simulcast.BestQualityFirst && track.RID() != fullResolution { - simulcastToSessionJoined = true - p.session.AddRouter(router) - } - p.mu.Lock() - p.routers[recv.Track().ID()] = router - p.mu.Unlock() - log.Debugf("Created router %s %d", p.id, recv.Track().SSRC()) - } else { - if !simulcastToSessionJoined && - (router.Config().Simulcast.BestQualityFirst && track.RID() == fullResolution || - !router.Config().Simulcast.BestQualityFirst && track.RID() == quarterResolution) { - simulcastToSessionJoined = true - p.session.AddRouter(router) - } - router.AddReceiver(recv) - } - - recv.OnCloseHandler(func() { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.routers, track.ID()) - }) - - if track.Kind() == webrtc.RTPCodecTypeVideo { - recv.OnLostHandler(func(nack *rtcp.TransportLayerNack) { - log.Debugf("Writing nack to peer: %s, ssrc: %d, missing sn: %d, bitmap: %b", p.id, track.SSRC(), nack.Nacks[0].PacketID, nack.Nacks[0].LostPackets) - if err := p.pc.WriteRTCP([]rtcp.Packet{nack}); err != nil { - log.Errorf("write nack rtcp err: %v", err) - } - }) - } - + log.Debugf("Peer %s got remote track id: %s mediaSSRC: %d rid :%s streamID: %s", p.id, track.ID(), track.SSRC(), track.RID(), track.Label()) + p.router.AddReceiver(ctx, track, receiver) if p.onTrackHandler != nil { p.onTrackHandler(track, receiver) } }) pc.OnDataChannel(func(d *webrtc.DataChannel) { - log.Debugf("New DataChannel %s %d\n", d.Label(), d.ID()) + log.Debugf("New DataChannel %s %d", d.Label(), d.ID()) // Register text message handling if d.Label() == channelLabel { handleAPICommand(p, d) @@ -180,12 +107,11 @@ func NewWebRTCTransport(ctx context.Context, session *Session, me MediaEngine, c if err := p.Close(); err != nil { log.Errorf("webrtc transport close err: %v", err) } + p.router.Stop() } } }) - go p.sendRTCP() - return p, nil } @@ -195,26 +121,36 @@ func (p *WebRTCTransport) CreateOffer() (webrtc.SessionDescription, error) { if err != nil { return webrtc.SessionDescription{}, err } - parsed := sdp.SessionDescription{} if err := parsed.Unmarshal([]byte(offer.SDP)); err == nil { for _, md := range parsed.MediaDescriptions { - if mid, ok := md.Attribute(sdp.AttrKeyMID); ok { - if msid, ok := md.Attribute(sdp.AttrKeyMsid); ok { - split := strings.Split(msid, " ") - if len(split) != 2 { - log.Errorf("Invalid msid: %s", msid) - continue + if md.MediaName.Media != mediaNameAudio && md.MediaName.Media != mediaNameVideo { + continue + } + var msid, mid string + + for _, att := range md.Attributes { + switch att.Key { + case sdp.AttrKeyMID: + mid = att.Value + if len(msid) > 0 { + break } - - msid := split[0] - tid := split[1] - - // find sender for mid - for _, sender := range p.senders[msid] { - if sender.Track().ID() == tid { - p.mids[mid] = sender - } + case sdp.AttrKeyMsid: + msid = att.Value + if len(mid) > 0 { + break + } + } + } + if len(msid) > 0 && len(mid) > 0 { + split := strings.Split(msid, " ") + sid := split[0] + tid := split[1] + // find sender for mid + for _, sender := range p.senders[sid] { + if sender.Track().ID() == tid { + p.mids[mid] = sender } } } @@ -248,7 +184,12 @@ func (p *WebRTCTransport) CreateAnswer() (webrtc.SessionDescription, error) { // SetRemoteDescription sets the SessionDescription of the remote peer func (p *WebRTCTransport) SetRemoteDescription(desc webrtc.SessionDescription) error { - err := p.pc.SetRemoteDescription(desc) + pd, err := desc.Unmarshal() + if err != nil { + log.Errorf("SetRemoteDescription error: %v", err) + return err + } + err = p.pc.SetRemoteDescription(desc) if err != nil { log.Errorf("SetRemoteDescription error: %v", err) return err @@ -264,17 +205,40 @@ func (p *WebRTCTransport) SetRemoteDescription(desc webrtc.SessionDescription) e p.candidates = nil } - parsed := sdp.SessionDescription{} - if err := parsed.Unmarshal([]byte(desc.SDP)); err == nil { - for _, md := range parsed.MediaDescriptions { - if mid, ok := md.Attribute(sdp.AttrKeyMID); ok { - if p.mids[mid] != nil { - p.mids[mid].Start() - // remove mid mapping incase transceiver is reused later - p.mids[mid] = nil + for _, md := range pd.MediaDescriptions { + if md.MediaName.Media != mediaNameAudio && md.MediaName.Media != mediaNameVideo { + continue + } + var ( + ext int + id string + ) + + for _, att := range md.Attributes { + if att.Key == sdp.AttrKeyMID { + if p.mids[att.Value] != nil { + p.mids[att.Value].Start() + // remove mid mapping in case transceiver is reused later + p.mids[att.Value] = nil + } + } + + if att.Key == sdp.AttrKeyExtMap && strings.HasSuffix(att.Value, sdp.TransportCCURI) { + ext, _ = strconv.Atoi(att.Value[:1]) + if len(id) > 0 { + break + } + } + if att.Key == sdp.AttrKeyMsid { + v := strings.Split(att.Value, " ") + id = v[len(v)-1] + if ext != 0 { + break } } } + p.router.AddTWCCExt(id, ext) + } return nil @@ -338,18 +302,9 @@ func (p *WebRTCTransport) ID() string { return p.id } -// Routers returns routers for this peer -func (p *WebRTCTransport) Routers() map[string]Router { - p.mu.RLock() - defer p.mu.RUnlock() - return p.routers -} - -// GetRouter returns router with ssrc -func (p *WebRTCTransport) GetRouter(trackID string) Router { - p.mu.RLock() - defer p.mu.RUnlock() - return p.routers[trackID] +// GetRouter returns router with mediaSSRC +func (p *WebRTCTransport) GetRouter() Router { + return p.router } func (p *WebRTCTransport) AddSender(streamID string, sender Sender) { @@ -375,26 +330,3 @@ func (p *WebRTCTransport) Close() error { p.cancel() return p.pc.Close() } - -func (p *WebRTCTransport) sendRTCP() { - t := time.NewTicker(time.Second) - for { - select { - case <-t.C: - pkts := make([]rtcp.Packet, 0) - p.mu.RLock() - for _, r := range p.routers { - pkts = append(pkts, r.GetRTCP()...) - } - p.mu.RUnlock() - if len(pkts) > 0 { - if err := p.pc.WriteRTCP(pkts); err != nil { - log.Errorf("write rtcp err: %v", err) - } - } - case <-p.ctx.Done(): - t.Stop() - return - } - } -} diff --git a/pkg/webrtctransport_test.go b/pkg/webrtctransport_test.go index b1a0a8a62..6dd1b51b3 100644 --- a/pkg/webrtctransport_test.go +++ b/pkg/webrtctransport_test.go @@ -231,7 +231,7 @@ func TestWebRTCTransport_CreateOffer(t *testing.T) { func TestWebRTCTransport_GetRouter(t *testing.T) { type fields struct { - routers map[string]Router + router Router } type args struct { trackID string @@ -248,7 +248,7 @@ func TestWebRTCTransport_GetRouter(t *testing.T) { { name: "Must return router by ID", fields: fields{ - routers: map[string]Router{"test": router}, + router: router, }, args: args{ trackID: "test", @@ -260,9 +260,9 @@ func TestWebRTCTransport_GetRouter(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { p := &WebRTCTransport{ - routers: tt.fields.routers, + router: tt.fields.router, } - if got := p.GetRouter(tt.args.trackID); !reflect.DeepEqual(got, tt.want) { + if got := p.GetRouter(); !reflect.DeepEqual(got, tt.want) { t.Errorf("GetRouter() = %v, want %v", got, tt.want) } }) @@ -544,39 +544,6 @@ func TestWebRTCTransport_OnTrack(t *testing.T) { } } -func TestWebRTCTransport_Routers(t *testing.T) { - type fields struct { - routers map[string]Router - } - - routers := map[string]Router{"test": &router{}} - - tests := []struct { - name string - fields fields - want map[string]Router - }{ - { - name: "Must return current map of routers", - fields: fields{ - routers: routers, - }, - want: routers, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - p := &WebRTCTransport{ - routers: tt.fields.routers, - } - if got := p.Routers(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("Routers() = %v, want %v", got, tt.want) - } - }) - } -} - func TestWebRTCTransport_SetLocalDescription(t *testing.T) { me := webrtc.MediaEngine{} me.RegisterDefaultCodecs() @@ -672,6 +639,10 @@ func TestWebRTCTransport_SetRemoteDescription(t *testing.T) { t.Run(tt.name, func(t *testing.T) { p := &WebRTCTransport{ pc: tt.fields.pc, + router: &RouterMock{ + AddTWCCExtFunc: func(_ string, _ int) { + }, + }, } if err := p.SetRemoteDescription(tt.args.desc); (err != nil) != tt.wantErr { t.Errorf("SetRemoteDescription() error = %v, wantErr %v", err, tt.wantErr)