diff --git a/Cargo.lock b/Cargo.lock index 29841e78..a1eb8ddc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -95,7 +95,7 @@ dependencies = [ "nom", "num-traits", "rusticata-macros", - "thiserror", + "thiserror 1.0.58", "time", ] @@ -156,15 +156,24 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.7" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bitflags" -version = "1.3.2" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] [[package]] name = "bumpalo" @@ -190,6 +199,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "clap" version = "4.5.4" @@ -248,9 +263,9 @@ dependencies = [ [[package]] name = "core-foundation" -version = "0.9.4" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" dependencies = [ "core-foundation-sys", "libc", @@ -258,9 +273,28 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] [[package]] name = "data-encoding" @@ -291,6 +325,16 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.4" @@ -406,6 +450,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.12" @@ -413,8 +467,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -490,9 +546,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.153" +version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "linked-hash-map" @@ -732,9 +788,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "pem" -version = "3.0.3" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" +checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3" dependencies = [ "base64", "serde", @@ -781,18 +837,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.79" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" dependencies = [ "unicode-ident", ] [[package]] name = "quinn" -version = "0.10.2" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cc2c5017e4b43d5995dcea317bc46c1e09404c0a9664d2908f7f02dfe943d75" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", @@ -800,40 +856,44 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "thiserror", + "socket2", + "thiserror 2.0.12", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.10.6" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "141bf7dfde2fbc246bfd3fe12f2455aa24b0fbd9af535d8c86c7bd1381ff2b1a" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", - "ring 0.16.20", + "ring", "rustc-hash", "rustls", - "rustls-native-certs", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.12", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.4.1" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "055b4e778e8feb9f93c4e439f71dc2156ef13360b432b799e179a8c4cdf0b1d7" +checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" dependencies = [ - "bytes", + "cfg_aliases", "libc", + "once_cell", "socket2", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -877,12 +937,13 @@ dependencies = [ [[package]] name = "rcgen" -version = "0.12.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" dependencies = [ "pem", - "ring 0.17.8", + "ring", + "rustls-pki-types", "time", "yasna", ] @@ -931,21 +992,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" -[[package]] -name = "ring" -version = "0.16.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" -dependencies = [ - "cc", - "libc", - "once_cell", - "spin 0.5.2", - "untrusted 0.7.1", - "web-sys", - "winapi", -] - [[package]] name = "ring" version = "0.17.8" @@ -956,8 +1002,8 @@ dependencies = [ "cfg-if", "getrandom", "libc", - "spin 0.9.8", - "untrusted 0.9.0", + "spin", + "untrusted", "windows-sys 0.52.0", ] @@ -969,9 +1015,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustc-hash" -version = "1.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rusticata-macros" @@ -984,61 +1030,57 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.10" +version = "0.23.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ - "log", - "ring 0.17.8", + "once_cell", + "ring", + "rustls-pki-types", "rustls-webpki", - "sct", + "subtle", + "zeroize", ] [[package]] name = "rustls-native-certs" -version = "0.6.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 1.0.4", + "rustls-pki-types", "schannel", "security-framework", ] [[package]] name = "rustls-pemfile" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" -dependencies = [ - "base64", -] - -[[package]] -name = "rustls-pemfile" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f48172685e6ff52a556baa527774f61fcaa884f59daf3375c62a3f1cd2549dab" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64", "rustls-pki-types", ] [[package]] name = "rustls-pki-types" -version = "1.4.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" -version = "0.101.7" +version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", + "ring", + "rustls-pki-types", + "untrusted", ] [[package]] @@ -1050,21 +1092,11 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "sct" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" -dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", -] - [[package]] name = "security-framework" -version = "2.10.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ "bitflags", "core-foundation", @@ -1075,9 +1107,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.10.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -1114,6 +1146,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1148,12 +1191,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - [[package]] name = "spin" version = "0.9.8" @@ -1166,11 +1203,17 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" -version = "2.0.58" +version = "2.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" +checksum = "e02e925281e18ffd9d640e234264753c43edc62d64b2d4cf898f1bc5e75f3fc2" dependencies = [ "proc-macro2", "quote", @@ -1194,7 +1237,16 @@ version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.58", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", ] [[package]] @@ -1208,6 +1260,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -1271,6 +1334,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" dependencies = [ "backtrace", + "bytes", "libc", "mio", "num_cpus", @@ -1314,7 +1378,6 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -1379,6 +1442,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -1400,12 +1469,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "untrusted" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" - [[package]] name = "untrusted" version = "0.9.0" @@ -1435,6 +1498,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1517,6 +1586,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1682,20 +1761,21 @@ dependencies = [ [[package]] name = "wtransport" -version = "0.1.12" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "358eb4bb09681342ab5284da0a4a2443183509874e8417e85391850786d08f33" +checksum = "93a724f65db90b6a1ffa92ea4966cf03cb7e2bcd3ef7135b84dfe4339640d1b9" dependencies = [ "bytes", + "pem", "quinn", "rcgen", - "ring 0.17.8", "rustls", "rustls-native-certs", - "rustls-pemfile 2.1.1", + "rustls-pemfile", "rustls-pki-types", + "sha2", "socket2", - "thiserror", + "thiserror 1.0.58", "time", "tokio", "tracing", @@ -1706,13 +1786,13 @@ dependencies = [ [[package]] name = "wtransport-proto" -version = "0.1.12" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "394b6588219ddc2ad6bd4ffd2a94d5ad6858ecf02acdd7d39680ba8933103eb9" +checksum = "14e4882c24a62f15024609b3688e6e29a4c2129634d27debc849ccd3f9b9690b" dependencies = [ "httlib-huffman", "octets", - "thiserror", + "thiserror 1.0.58", "url", ] @@ -1729,7 +1809,7 @@ dependencies = [ "nom", "oid-registry", "rusticata-macros", - "thiserror", + "thiserror 1.0.58", "time", ] @@ -1741,3 +1821,9 @@ checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" dependencies = [ "time", ] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/README.md b/README.md index da94640f..35c33bd4 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ Supported version: draft-ietf-moq-transport-06 - [x] ANNOUNCE_ERROR - [ ] ANNOUNCE_CANCEL - [ ] TRACK_STATUS_REQUEST - - [x] SUBSCRIBE_NAMESPACE - - [ ] UNSUBSCRIBE_NAMESPACE + - [x] SUBSCRIBE_ANNOUNCES + - [ ] UNSUBSCRIBE_ANNOUNCES - [x] SUBSCRIBE_OK - [x] SUBSCRIBE_ERROR - [ ] SUBSCRIBE_DONE @@ -26,11 +26,14 @@ Supported version: draft-ietf-moq-transport-06 - [x] ANNOUNCE - [ ] UNANNOUNCE - [ ] TRACK_STATUS - - [x] SUBSCRIBE_NAMESPACE_OK - - [x] SUBSCRIBE_NAMESPACE_ERROR + - [x] SUBSCRIBE_ANNOUNCES_OK + - [x] SUBSCRIBE_ANNOUNCES_ERROR + - [ ] FETCH + - [ ] FETCH_OK + - [ ] FETCH_ERROR + - [ ] FETCH_CANCEL - [x] Data Streams - [x] Datagram - - [x] Track Stream - [x] Subgroup Stream - [ ] Features - [x] Manage Publisher / Subscriber diff --git a/docs/flow.drawio.svg b/docs/flow.drawio.svg index 74c44ec5..46273347 100644 --- a/docs/flow.drawio.svg +++ b/docs/flow.drawio.svg @@ -92,7 +92,7 @@
- send_stream + control_message
_dispatcher
@@ -102,7 +102,7 @@
- send_stream... + control_message... @@ -435,13 +435,13 @@
- SendStreamDispatcherへの登録処理 + ControlMessageDispatcherへの登録処理
- SendStreamDispatcher... + ControlMessageDispatcher... diff --git a/js/examples/media/index.html b/js/examples/media/index.html new file mode 100644 index 00000000..84c01ae7 --- /dev/null +++ b/js/examples/media/index.html @@ -0,0 +1,15 @@ + + + + + + Media Examples + + +

Media Examples

+
+ Publisher + Subscriber +
+ + diff --git a/js/examples/media/publisher/audioEncoder.ts b/js/examples/media/publisher/audioEncoder.ts new file mode 100644 index 00000000..8784fbf7 --- /dev/null +++ b/js/examples/media/publisher/audioEncoder.ts @@ -0,0 +1,47 @@ +let audioEncoder: AudioEncoder | undefined +const AUDIO_ENCODER_CONFIG = { + codec: 'opus', + sampleRate: 48000, // Opusの推奨サンプルレート + numberOfChannels: 1, // モノラル + bitrate: 64000 // 64kbpsのビットレート +} + +function sendAudioChunkMessage(chunk: EncodedAudioChunk, metadata: EncodedAudioChunkMetadata | undefined) { + self.postMessage({ chunk, metadata }) +} + +async function initializeAudioEncoder() { + const init: AudioEncoderInit = { + output: sendAudioChunkMessage, + error: (e: any) => { + console.log(e.message) + } + } + + const encoder = new AudioEncoder(init) + encoder.configure(AUDIO_ENCODER_CONFIG) + return encoder +} + +async function startAudioEncode(audioReadableStream: ReadableStream) { + if (!audioEncoder) { + audioEncoder = await initializeAudioEncoder() + } + const audioReader = audioReadableStream.getReader() + while (true) { + const audioResult = await audioReader.read() + if (audioResult.done) break + const audio = audioResult.value + audioEncoder.encode(audio) + audio.close() + } +} + +self.onmessage = async (event) => { + const audioReadableStream: ReadableStream = event.data.audioStream + if (!audioReadableStream) { + console.error('MediaStreamTrack が渡されていません') + return + } + await startAudioEncode(audioReadableStream) +} diff --git a/js/examples/media/publisher/const.ts b/js/examples/media/publisher/const.ts new file mode 100644 index 00000000..73c6e2cb --- /dev/null +++ b/js/examples/media/publisher/const.ts @@ -0,0 +1,2 @@ +export const KEYFRAME_INTERVAL = 300 +export const AUTH_INFO = 'secret' diff --git a/js/examples/media/publisher/index.html b/js/examples/media/publisher/index.html new file mode 100644 index 00000000..36403597 --- /dev/null +++ b/js/examples/media/publisher/index.html @@ -0,0 +1,95 @@ + + + + + + MOQT Media Publisher Test + + +

Publisher

+
+
+
+

Connection

+ + + + +
+
+
+

Forwarding Preference

+ + +
+ *You must select at first + +

CLIENT_SETUP

+ Max Subscribe ID: +
+ + +

ANNOUNCE

+ +
+ +

OBJECT

+
+

Header

+
+ + +
+ + +
+
+

Object

+
+ +
+
+ +
+ +
+ +
+
+
+ +
+
+ +
+
+
+ + + + + diff --git a/js/examples/media/publisher/main.ts b/js/examples/media/publisher/main.ts new file mode 100644 index 00000000..efec0562 --- /dev/null +++ b/js/examples/media/publisher/main.ts @@ -0,0 +1,245 @@ +import init, { MOQTClient } from '../../../pkg/moqt_client_sample' +import { AUTH_INFO, KEYFRAME_INTERVAL } from './const' +import { sendVideoObjectMessage, sendAudioObjectMessage } from './sender' +import { getFormElement } from './utils' + +let mediaStream: MediaStream | null = null +function setUpStartGetUserMediaButton() { + const startGetUserMediaBtn = document.getElementById('startGetUserMediaBtn') as HTMLButtonElement + startGetUserMediaBtn.addEventListener('click', async () => { + const constraints = { + audio: true, + video: true + } + mediaStream = await navigator.mediaDevices.getUserMedia(constraints) + const video = document.getElementById('video') as HTMLVideoElement + video.srcObject = mediaStream + }) +} + +const LatestMediaTrackInfo: { + video: { + objectId: bigint + groupId: bigint + subgroups: {} + } + audio: { + objectId: bigint + groupId: bigint + subgroups: { + [key: number]: { + isSendedSubgroupHeader: boolean + } + } + } +} = { + video: { + objectId: 0n, + groupId: 0n, + subgroups: { + 0: {}, + 1: {}, + 2: {} + } + }, + audio: { + objectId: 0n, + groupId: 0n, + subgroups: { + 0: { isSendedSubgroupHeader: false } + } + } +} + +const videoEncoderWorker = new Worker('videoEncoder.ts') +async function handleVideoChunkMessage( + chunk: EncodedVideoChunk, + metadata: EncodedVideoChunkMetadata | undefined, + client: MOQTClient +) { + const form = getFormElement() + const trackAlias = form['video-object-track-alias'].value + const publisherPriority = form['video-publisher-priority'].value + + // Increment the groupId and reset the objectId at the timing of the keyframe + // Then, resend the SubgroupStreamHeader + if (chunk.type === 'key') { + LatestMediaTrackInfo['video'].groupId++ + LatestMediaTrackInfo['video'].objectId = BigInt(0) + + const subgroupKeys = Object.keys(LatestMediaTrackInfo['video'].subgroups).map(BigInt) + for (const subgroup of subgroupKeys) { + await client.sendSubgroupStreamHeaderMessage( + BigInt(trackAlias), + LatestMediaTrackInfo['video'].groupId, + // @ts-ignore - The SVC property is not defined in the standard but actually exists + subgroup, + publisherPriority + ) + console.log('send subgroup stream header') + } + } + + sendVideoObjectMessage( + trackAlias, + LatestMediaTrackInfo['video'].groupId, + // @ts-ignore - The SVC property is not defined in the standard but actually exists + BigInt(metadata?.svc.temporalLayerId), // = subgroupId + LatestMediaTrackInfo['video'].objectId, + chunk, + metadata, + client + ) + LatestMediaTrackInfo['video'].objectId++ +} + +const audioEncoderWorker = new Worker('audioEncoder.ts') +async function handleAudioChunkMessage( + chunk: EncodedAudioChunk, + metadata: EncodedAudioChunkMetadata | undefined, + client: MOQTClient +) { + const form = getFormElement() + const trackAlias = form['audio-object-track-alias'].value + const publisherPriority = form['audio-publisher-priority'].value + const subgroupId = 0 + + if (!LatestMediaTrackInfo['audio']['subgroups'][subgroupId].isSendedSubgroupHeader) { + await client.sendSubgroupStreamHeaderMessage( + BigInt(trackAlias), + LatestMediaTrackInfo['audio'].groupId, + BigInt(subgroupId), + publisherPriority + ) + console.log('send subgroup stream header') + LatestMediaTrackInfo['audio']['subgroups'][subgroupId].isSendedSubgroupHeader = true + } + + sendAudioObjectMessage( + trackAlias, + LatestMediaTrackInfo['audio'].groupId, + BigInt(subgroupId), + LatestMediaTrackInfo['audio'].objectId, + chunk, + metadata, + client + ) + LatestMediaTrackInfo['audio'].objectId++ +} + +function setupClientCallbacks(client: MOQTClient): void { + client.onSetup(async (serverSetup: any) => { + console.log({ serverSetup }) + }) + + client.onAnnounce(async (announceMessage: any) => { + console.log({ announceMessage }) + const announcedNamespace = announceMessage.track_namespace + + await client.sendAnnounceOkMessage(announcedNamespace) + }) + + client.onAnnounceResponce(async (announceResponceMessage: any) => { + console.log({ announceResponceMessage }) + }) + + client.onSubscribe(async (subscribeMessage: any, isSuccess: any, code: any) => { + console.log({ subscribeMessage }) + const form = getFormElement() + const receivedSubscribeId = BigInt(subscribeMessage.subscribe_id) + const receivedTrackAlias = BigInt(subscribeMessage.track_alias) + console.log('subscribeId', receivedSubscribeId, 'trackAlias', receivedTrackAlias) + + if (isSuccess) { + const expire = 0n + const forwardingPreference = (Array.from(form['forwarding-preference']) as HTMLInputElement[]).filter( + (elem) => elem.checked + )[0].value + await client.sendSubscribeOkMessage(receivedSubscribeId, expire, AUTH_INFO, forwardingPreference) + } else { + const reasonPhrase = 'subscribe error' + await client.sendSubscribeErrorMessage(subscribeMessage.subscribe_id, code, reasonPhrase) + } + }) +} + +function sendSetupButtonClickHandler(client: MOQTClient): void { + const sendSetupBtn = document.getElementById('sendSetupBtn') as HTMLButtonElement + sendSetupBtn.addEventListener('click', async () => { + const form = getFormElement() + + const versions = new BigUint64Array('0xff00000A'.split(',').map(BigInt)) + const maxSubscribeId = BigInt(form['max-subscribe-id'].value) + + await client.sendSetupMessage(versions, maxSubscribeId) + }) +} + +function sendAnnounceButtonClickHandler(client: MOQTClient): void { + const sendAnnounceBtn = document.getElementById('sendAnnounceBtn') as HTMLButtonElement + sendAnnounceBtn.addEventListener('click', async () => { + const form = getFormElement() + const trackNamespace = form['announce-track-namespace'].value.split('/') + + await client.sendAnnounceMessage(trackNamespace, AUTH_INFO) + }) +} + +function sendSubgroupObjectButtonClickHandler(client: MOQTClient): void { + const sendSubgroupObjectBtn = document.getElementById('sendSubgroupObjectBtn') as HTMLButtonElement + sendSubgroupObjectBtn.addEventListener('click', async () => { + if (mediaStream == null) { + console.error('mediaStream is null') + return + } + videoEncoderWorker.onmessage = async (e: MessageEvent) => { + const { chunk, metadata } = e.data as { + chunk: EncodedVideoChunk + metadata: EncodedVideoChunkMetadata | undefined + } + console.log(chunk, metadata) + handleVideoChunkMessage(chunk, metadata, client) + } + audioEncoderWorker.onmessage = async (e: MessageEvent) => { + const { chunk, metadata } = e.data as { + chunk: EncodedAudioChunk + metadata: EncodedAudioChunkMetadata | undefined + } + handleAudioChunkMessage(chunk, metadata, client) + } + + const [videoTrack] = mediaStream.getVideoTracks() + const videoProcessor = new MediaStreamTrackProcessor({ track: videoTrack }) + const videoStream = videoProcessor.readable + videoEncoderWorker.postMessage({ + type: 'keyframeInterval', + keyframeInterval: KEYFRAME_INTERVAL + }) + videoEncoderWorker.postMessage({ type: 'videoStream', videoStream: videoStream }, [videoStream]) + const [audioTrack] = mediaStream.getAudioTracks() + const audioProcessor = new MediaStreamTrackProcessor({ track: audioTrack }) + const audioStream = audioProcessor.readable + audioEncoderWorker.postMessage({ type: 'audioStream', audioStream: audioStream }, [audioStream]) + }) +} + +function setupButtonClickHandler(client: MOQTClient): void { + sendSetupButtonClickHandler(client) + sendAnnounceButtonClickHandler(client) + sendSubgroupObjectButtonClickHandler(client) +} + +init().then(async () => { + setUpStartGetUserMediaButton() + + const connectBtn = document.getElementById('connectBtn') as HTMLButtonElement + connectBtn.addEventListener('click', async () => { + const form = getFormElement() + const url = form.url.value + const client = new MOQTClient(url) + setupClientCallbacks(client) + setupButtonClickHandler(client) + + await client.start() + }) +}) diff --git a/js/examples/media/publisher/sender.ts b/js/examples/media/publisher/sender.ts new file mode 100644 index 00000000..9fd4bafb --- /dev/null +++ b/js/examples/media/publisher/sender.ts @@ -0,0 +1,87 @@ +import { MOQTClient } from '../../../pkg/moqt_client_sample' +import { KEYFRAME_INTERVAL } from './const' + +export async function sendVideoObjectMessage( + trackAlias: bigint, + groupId: bigint, + subgroupId: bigint, + objectId: bigint, + chunk: EncodedVideoChunk, + metadata: EncodedVideoChunkMetadata | undefined, + client: MOQTClient +) { + // `EncodedVideoChunk` のデータを Uint8Array に変換 + const chunkArray = new Uint8Array(chunk.byteLength) + chunk.copyTo(chunkArray) + + const chunkData = { + type: chunk.type, + timestamp: chunk.timestamp, + duration: chunk.duration, + byteLength: chunk.byteLength, + data: Array.from(chunkArray), + decoderConfig: { + codec: metadata?.decoderConfig?.codec, + codedHeight: metadata?.decoderConfig?.codedHeight, + codedWidth: metadata?.decoderConfig?.codedWidth, + colorSpace: metadata?.decoderConfig?.colorSpace, + description: metadata?.decoderConfig?.description, + displayAspectHeight: metadata?.decoderConfig?.displayAspectHeight, + displayAspectWidth: metadata?.decoderConfig?.displayAspectWidth, + hardwareAcceleration: metadata?.decoderConfig?.hardwareAcceleration, + optimizeForLatency: metadata?.decoderConfig?.optimizeForLatency + }, + temporalLayer: metadata?.temporalLayerId + } + + const encoder = new TextEncoder() + const jsonString = JSON.stringify({ chunk: chunkData }) + const objectPayload = encoder.encode(jsonString) + + await client.sendSubgroupStreamObject(BigInt(trackAlias), groupId, subgroupId, objectId, undefined, objectPayload) + // If this object is end of group, send the ObjectStatus=EndOfGroupMessage. + // And delete unnecessary streams. + if (objectId === BigInt(KEYFRAME_INTERVAL - 1)) { + await client.sendSubgroupStreamObject( + BigInt(trackAlias), + groupId, + subgroupId, + BigInt(KEYFRAME_INTERVAL), + 3, // 0x3: EndOfGroup + Uint8Array.from([]) + ) + console.log('send Object(ObjectStatus=EndOfGroup)') + } +} + +export async function sendAudioObjectMessage( + trackAlias: bigint, + groupId: bigint, + subgroupId: bigint, + objectId: bigint, + chunk: EncodedAudioChunk, + metadata: EncodedAudioChunkMetadata | undefined, + client: MOQTClient +) { + // `EncodedAudioChunk` のデータを Uint8Array に変換 + const chunkArray = new Uint8Array(chunk.byteLength) + chunk.copyTo(chunkArray) + + const chunkData = { + type: chunk.type, + timestamp: chunk.timestamp, + duration: chunk.duration, + byteLength: chunk.byteLength, + data: Array.from(chunkArray), + decoderConfig: { + codec: metadata?.decoderConfig?.codec, + numberOfChannels: metadata?.decoderConfig?.numberOfChannels, + sampleRate: metadata?.decoderConfig?.sampleRate + } + } + + const encoder = new TextEncoder() + const jsonString = JSON.stringify({ chunk: chunkData }) + const objectPayload = encoder.encode(jsonString) + await client.sendSubgroupStreamObject(BigInt(trackAlias), groupId, subgroupId, objectId, undefined, objectPayload) +} diff --git a/js/examples/media/publisher/utils.ts b/js/examples/media/publisher/utils.ts new file mode 100644 index 00000000..b066e725 --- /dev/null +++ b/js/examples/media/publisher/utils.ts @@ -0,0 +1,3 @@ +export const getFormElement = (): HTMLFormElement => { + return document.getElementById('form') as HTMLFormElement +} diff --git a/js/examples/media/publisher/videoEncoder.ts b/js/examples/media/publisher/videoEncoder.ts new file mode 100644 index 00000000..09ef48b1 --- /dev/null +++ b/js/examples/media/publisher/videoEncoder.ts @@ -0,0 +1,64 @@ +let videoEncoder: VideoEncoder | undefined +let keyframeInterval: number +const VIDEO_ENCODER_CONFIG = { + codec: 'av01.0.04M.08', + width: 640, + height: 480, + bitrate: 2_000_000, // 2 Mbps + scalabilityMode: 'L1T3', + framerate: 30 +} + +function sendVideoChunkMessage(chunk: EncodedVideoChunk, metadata: EncodedVideoChunkMetadata | undefined) { + self.postMessage({ chunk, metadata }) +} + +async function initializeVideoEncoder() { + const init: VideoEncoderInit = { + output: sendVideoChunkMessage, + error: (e: any) => { + console.log(e.message) + } + } + + const encoder = new VideoEncoder(init) + encoder.configure(VIDEO_ENCODER_CONFIG) + return encoder +} + +async function startVideoEncode(videoReadableStream: ReadableStream) { + let frameCounter = 0 + if (!videoEncoder) { + videoEncoder = await initializeVideoEncoder() + } + const videoReader = videoReadableStream.getReader() + while (true) { + const videoResult = await videoReader.read() + if (videoResult.done) break + const videoFrame = videoResult.value + + // Too many frames in flight, encoder is overwhelmed. let's drop this frame. + if (videoEncoder.encodeQueueSize > 2) { + console.error('videoEncoder.encodeQueueSize > 2', videoEncoder.encodeQueueSize) + videoFrame.close() + } else { + const keyFrame = frameCounter % keyframeInterval == 0 + videoEncoder.encode(videoFrame, { keyFrame }) + frameCounter++ + videoFrame.close() + } + } +} + +self.onmessage = async (event) => { + if (event.data.type === 'keyframeInterval') { + keyframeInterval = event.data.keyframeInterval + } else if (event.data.type === 'videoStream') { + const videoReadableStream: ReadableStream = event.data.videoStream + if (!videoReadableStream) { + console.error('MediaStreamTrack が渡されていません') + return + } + await startVideoEncode(videoReadableStream) + } +} diff --git a/js/examples/media/subscriber/audioDecoder.ts b/js/examples/media/subscriber/audioDecoder.ts new file mode 100644 index 00000000..4831efa7 --- /dev/null +++ b/js/examples/media/subscriber/audioDecoder.ts @@ -0,0 +1,59 @@ +function sendAudioDataMessage(audioData: AudioData): void { + self.postMessage({ audioData }) + audioData.close() +} + +let audioDecoder: AudioDecoder | undefined +async function initializeAudioDecoder() { + const init: AudioDecoderInit = { + output: sendAudioDataMessage, + error: (e: any) => { + console.log(e.message) + } + } + const config = { + codec: 'opus', + sampleRate: 48000, // Opusの推奨サンプルレート + numberOfChannels: 1, // モノラル + bitrate: 64000 // 64kbpsのビットレート + } + const decoder = new AudioDecoder(init) + decoder.configure(config) + return decoder +} + +namespace AudioDecoder { + export type SubgroupStreamObject = { + objectId: number + objectPayloadLength: number + objectPayload: Uint8Array + objectStatus: any + } +} + +self.onmessage = async (event) => { + if (!audioDecoder) { + audioDecoder = await initializeAudioDecoder() + } + + const subgroupStreamObject: AudioDecoder.SubgroupStreamObject = { + objectId: event.data.subgroupStreamObject.object_id, + objectPayloadLength: event.data.subgroupStreamObject.object_payload_length, + objectPayload: event.data.subgroupStreamObject.object_payload, + objectStatus: event.data.subgroupStreamObject.object_status + } + // Rustから渡された時点ではUint8ArrayではなくArrayなので変換が必要 + const chunkArray = new Uint8Array(subgroupStreamObject.objectPayload) + const decoder = new TextDecoder() + const jsonString = decoder.decode(chunkArray) + const objectPayload = JSON.parse(jsonString) + + const encodedAudioChunk = new EncodedAudioChunk({ + type: objectPayload.chunk.type, + timestamp: objectPayload.chunk.timestamp, + duration: objectPayload.chunk.duration, + data: new Uint8Array(objectPayload.chunk.data) + }) + + await audioDecoder.decode(encodedAudioChunk) +} diff --git a/js/examples/media/subscriber/const.ts b/js/examples/media/subscriber/const.ts new file mode 100644 index 00000000..3cec97c7 --- /dev/null +++ b/js/examples/media/subscriber/const.ts @@ -0,0 +1 @@ +export const AUTH_INFO = 'secret' diff --git a/js/examples/media/subscriber/index.html b/js/examples/media/subscriber/index.html new file mode 100644 index 00000000..ebe8ef75 --- /dev/null +++ b/js/examples/media/subscriber/index.html @@ -0,0 +1,54 @@ + + + + + + MOQT Media Subscriber Test + + +

Subscriber

+
+
+
+

Connection

+ + + + +
+
+
+

Forwarding Preference

+ + +
+ +

CLIENT_SETUP

+ Max Subscribe ID: +
+ + + +

SUBSCRIBE

+ + +
+ + +
+ + +
+ + + + + diff --git a/js/examples/media/subscriber/main.ts b/js/examples/media/subscriber/main.ts new file mode 100644 index 00000000..290954a3 --- /dev/null +++ b/js/examples/media/subscriber/main.ts @@ -0,0 +1,132 @@ +import init, { MOQTClient } from '../../../pkg/moqt_client_sample' +import { AUTH_INFO } from './const' +import { getFormElement } from './utils' + +function setupClientCallbacks(client: MOQTClient) { + client.onSetup(async (serverSetup: any) => { + console.log({ serverSetup }) + }) +} + +function sendSetupButtonClickHandler(client: MOQTClient) { + const sendSetupBtn = document.getElementById('sendSetupBtn') as HTMLButtonElement + sendSetupBtn.addEventListener('click', async () => { + const form = getFormElement() + + const versions = new BigUint64Array('0xff00000A'.split(',').map(BigInt)) + const maxSubscribeId = BigInt(form['max-subscribe-id'].value) + + await client.sendSetupMessage(versions, maxSubscribeId) + }) +} + +function sendSubscribeButtonClickHandler(client: MOQTClient) { + const sendSubscribeBtn = document.getElementById('sendSubscribeBtn') as HTMLButtonElement + sendSubscribeBtn.addEventListener('click', async () => { + const form = getFormElement() + const trackNamespace = form['subscribe-track-namespace'].value.split('/') + setupClientObjectCallbacks(client, 'video', Number(0)) + await client.sendSubscribeMessage( + BigInt(0), + BigInt(0), + trackNamespace, + 'video', + 0, // subscriberPriority + 0, // groupOrder + 1, // Latest Group + BigInt(0), // startGroup + BigInt(0), // startObject + BigInt(10000), // endGroup + AUTH_INFO + ) + + setupClientObjectCallbacks(client, 'audio', Number(1)) + await client.sendSubscribeMessage( + BigInt(1), + BigInt(1), + trackNamespace, + 'audio', + 0, // subscriberPriority + 0, // groupOrder + 1, // Latest Group + BigInt(0), // startGroup + BigInt(0), // startObject + BigInt(10000), // endGroup + AUTH_INFO + ) + }) +} + +const audioDecoderWorker = new Worker('audioDecoder.ts') +function setupAudioDecoderWorker() { + const audioGenerator = new MediaStreamTrackGenerator({ kind: 'audio' }) + const audioWriter = audioGenerator.writable.getWriter() + const audioStream = new MediaStream([audioGenerator]) + const audioElement = document.getElementById('audio') as HTMLFormElement + audioElement.srcObject = audioStream + audioDecoderWorker.onmessage = async (e: MessageEvent) => { + const audioData = e.data.audioData + await audioWriter.write(audioData) + await audioElement.play() + } +} +const videoDecoderWorker = new Worker('videoDecoder.ts') +function setupVideoDecoderWorker() { + const videoGenerator = new MediaStreamTrackGenerator({ kind: 'video' }) + const videoWriter = videoGenerator.writable.getWriter() + const videoStream = new MediaStream([videoGenerator]) + const videoElement = document.getElementById('video') as HTMLFormElement + videoElement.srcObject = videoStream + videoDecoderWorker.onmessage = async (e: MessageEvent) => { + const videoFrame = e.data.frame + await videoWriter.write(videoFrame) + videoFrame.close() + await videoElement.play() + } +} + +function setupClientObjectCallbacks(client: MOQTClient, type: 'video' | 'audio', trackAlias: number) { + client.onSubgroupStreamHeader(async (subgroupStreamHeader: any) => { + console.log({ subgroupStreamHeader }) + }) + + if (type === 'audio') { + setupAudioDecoderWorker() + } else { + setupVideoDecoderWorker() + } + client.onSubgroupStreamObject(BigInt(trackAlias), async (subgroupStreamObject: any) => { + console.log(subgroupStreamObject) + if (type === 'video') { + if ( + subgroupStreamObject.objectPayloadLength === 0 || + subgroupStreamObject.object_status === 'EndOfGroup' || + subgroupStreamObject.object_status === 'EndOfTrackAndGroup' || + subgroupStreamObject.object_status === 'EndOfTrack' + ) { + console.log(subgroupStreamObject) + return + } + videoDecoderWorker.postMessage({ subgroupStreamObject }) + } else { + audioDecoderWorker.postMessage({ subgroupStreamObject }) + } + }) +} + +function setupButtonClickHandler(client: MOQTClient) { + sendSetupButtonClickHandler(client) + sendSubscribeButtonClickHandler(client) +} + +init().then(async () => { + const connectBtn = document.getElementById('connectBtn') as HTMLButtonElement + connectBtn.addEventListener('click', async () => { + const form = getFormElement() + const url = form.url.value + const client = new MOQTClient(url) + setupClientCallbacks(client) + setupButtonClickHandler(client) + await client.start() + }) +}) diff --git a/js/examples/media/subscriber/utils.ts b/js/examples/media/subscriber/utils.ts new file mode 100644 index 00000000..b066e725 --- /dev/null +++ b/js/examples/media/subscriber/utils.ts @@ -0,0 +1,3 @@ +export const getFormElement = (): HTMLFormElement => { + return document.getElementById('form') as HTMLFormElement +} diff --git a/js/examples/media/subscriber/videoDecoder.ts b/js/examples/media/subscriber/videoDecoder.ts new file mode 100644 index 00000000..9f5be15d --- /dev/null +++ b/js/examples/media/subscriber/videoDecoder.ts @@ -0,0 +1,65 @@ +function sendVideoFrameMessage(frame: VideoFrame): void { + self.postMessage({ frame }) + frame.close() +} + +let videoDecoder: VideoDecoder | undefined +async function initializeVideoDecoder() { + const init: VideoDecoderInit = { + output: sendVideoFrameMessage, + error: (e: any) => { + console.log(e.message) + videoDecoder = undefined + } + } + const config = { + codec: 'av01.0.04M.08', + width: 640, + height: 480, + scalabilityMode: 'L1T3' + } + const decoder = new VideoDecoder(init) + decoder.configure(config) + return decoder +} + +namespace VideoDecoder { + export type SubgroupStreamObject = { + objectId: number + objectPayloadLength: number + objectPayload: Uint8Array + objectStatus: any + } +} + +let keyframeDecoded = false +self.onmessage = async (event) => { + const subgroupStreamObject: VideoDecoder.SubgroupStreamObject = { + objectId: event.data.subgroupStreamObject.object_id, + objectPayloadLength: event.data.subgroupStreamObject.object_payload_length, + objectPayload: event.data.subgroupStreamObject.object_payload, + objectStatus: event.data.subgroupStreamObject.object_status + } + + const chunkArray = new Uint8Array(subgroupStreamObject.objectPayload) + const decoder = new TextDecoder() + const jsonString = decoder.decode(chunkArray) + const objectPayload = JSON.parse(jsonString) + + const encodedVideoChunk = new EncodedVideoChunk({ + type: objectPayload.chunk.type, + timestamp: objectPayload.chunk.timestamp, + duration: objectPayload.chunk.duration, + data: new Uint8Array(objectPayload.chunk.data) + }) + + if (!videoDecoder) { + videoDecoder = await initializeVideoDecoder() + // The first frame after initializing the decoder must be a keyframe + if (objectPayload.chunk.type !== 'key') { + return + } + } + + await videoDecoder.decode(encodedVideoChunk) +} diff --git a/js/examples/message/index.html b/js/examples/message/index.html new file mode 100644 index 00000000..e2b4563b --- /dev/null +++ b/js/examples/message/index.html @@ -0,0 +1,164 @@ + + + + + + MOQT Message Test + + +
+
+
+

Connection

+ + + + +
+
+
+

Forwarding Preference

+ + +
+ *You must select at first +

Message

+ +
+

CLIENT_SETUP

+ +
+ Max Subscribe ID: +
+ + +

ANNOUNCE

+ +
+ + +

SUBSCRIBE_ANNOUNCES

+ +
+ + +

SUBSCRIBE

+ +
+ +
+ +
+ +
+
+ +
+ Group Order: + + + +
+
+ Filter Type: + + + + +
+ Start Group: Start Object: + +
+ End Group: End Object: + +
+ + +

UNSUBSCRIBE

+ +
+ + +

OBJECT

+
+

Header

+
+ +
+
+ Group ID: +

0

+
+ + +
+ Subgroup ID: +

0

+
+ + +
+
+ +
+
+

Object

+
+ + Object ID: +

0

+
+
+ Object Status: + + + + + +
+
+ +
+ + +
+ +
+
+ +
+
+
+
+

Received

+
+
+
+
+ + + + + diff --git a/js/main.js b/js/examples/message/main.js similarity index 57% rename from js/main.js rename to js/examples/message/main.js index b3ee6a83..21d34ad7 100644 --- a/js/main.js +++ b/js/examples/message/main.js @@ -1,14 +1,13 @@ -import init, { MOQTClient } from './pkg/moqt_client_sample' +import init, { MOQTClient } from '../../pkg/moqt_client_sample' // TODO: impl close init().then(async () => { console.log('init wasm-pack') - let trackHeaderSent = false const subgroupHeaderSent = new Set() let objectId = 0n - let mutableGroupId = 0n - let mutableSubgroupId = 0n + let groupId = 0n + let subgroupId = 0n const connectBtn = document.getElementById('connectBtn') connectBtn.addEventListener('click', async () => { @@ -22,7 +21,7 @@ init().then(async () => { const describeReceivedObject = (payload) => { // change line - let brElement = document.createElement('br') + const brElement = document.createElement('br') receivedTextElement.prepend(brElement) // decode the object array to its text @@ -30,7 +29,7 @@ init().then(async () => { const receivedText = new TextDecoder().decode(receivedArray) // show received text - let receivedElement = document.createElement('p') + const receivedElement = document.createElement('p') receivedElement.textContent = receivedText receivedTextElement.prepend(receivedElement) } @@ -41,7 +40,7 @@ init().then(async () => { client.onAnnounce(async (announceMessage) => { console.log({ announceMessage }) - let announcedNamespace = announceMessage.track_namespace + const announcedNamespace = announceMessage.track_namespace await client.sendAnnounceOkMessage(announcedNamespace) }) @@ -53,17 +52,17 @@ init().then(async () => { client.onSubscribe(async (subscribeMessage, isSuccess, code) => { console.log({ subscribeMessage }) - let receivedSubscribeId = BigInt(subscribeMessage.subscribe_id) - let receivedTrackAlias = BigInt(subscribeMessage.track_alias) + const receivedSubscribeId = BigInt(subscribeMessage.subscribe_id) + const receivedTrackAlias = BigInt(subscribeMessage.track_alias) console.log('subscribeId', receivedSubscribeId, 'trackAlias', receivedTrackAlias) if (isSuccess) { - let expire = 0n + const expire = 0n const forwardingPreference = Array.from(form['forwarding-preference']).filter((elem) => elem.checked)[0].value await client.sendSubscribeOkMessage(receivedSubscribeId, expire, authInfo, forwardingPreference) } else { // TODO: set accurate reasonPhrase - let reasonPhrase = 'subscribe error' + const reasonPhrase = 'subscribe error' await client.sendSubscribeError(subscribeMessage.subscribe_id, code, reasonPhrase) } }) @@ -72,8 +71,8 @@ init().then(async () => { console.log({ subscribeResponse }) }) - client.onSubscribeNamespaceResponse(async (subscribeNamespaceResponse) => { - console.log({ subscribeNamespaceResponse }) + client.onSubscribeAnnouncesResponse(async (subscribeAnnouncesResponse) => { + console.log({ subscribeAnnouncesResponse }) }) client.onUnsubscribe(async (unsubscribeMessage) => { @@ -85,37 +84,36 @@ init().then(async () => { describeReceivedObject(datagramObject.object_payload) }) - client.onTrackStreamHeader(async (trackStreamHeader) => { - console.log({ trackStreamHeader }) - }) - - client.onTrackStreamObject(async (trackStreamObject) => { - console.log({ trackStreamObject }) - describeReceivedObject(trackStreamObject.object_payload) + client.onDatagramObjectStatus(async (datagramObjectStatus) => { + console.log({ datagramObjectStatus }) }) client.onSubgroupStreamHeader(async (subgroupStreamHeader) => { console.log({ subgroupStreamHeader }) }) - client.onSubgroupStreamObject(async (subgroupStreamObject) => { + const trackAlias = form['subscribe-track-alias'].value + client.onSubgroupStreamObject(BigInt(trackAlias), async (subgroupStreamObject) => { console.log({ subgroupStreamObject }) - describeReceivedObject(subgroupStreamObject.object_payload) + const objectPayload = subgroupStreamObject.object_payload + + if (objectPayload.length > 0) { + describeReceivedObject(objectPayload) + } }) const objectIdElement = document.getElementById('objectId') - const mutableDatagramAndTrackGroupIdElement = document.getElementById('mutableDatagramAndTrackGroupId') - const mutableSubgroupGroupIdElement = document.getElementById('mutableSubgroupGroupId') - const mutableSubgroupIdElement = document.getElementById('mutableSubgroupId') + const datagramGroupIdElement = document.getElementById('datagramGroupId') + const subgroupGroupIdElement = document.getElementById('subgroupGroupId') + const subgroupIdElement = document.getElementById('subgroupId') const sendSetupBtn = document.getElementById('sendSetupBtn') sendSetupBtn.addEventListener('click', async () => { console.log('send setup btn clicked') - const role = Array.from(form['role']).filter((elem) => elem.checked)[0].value const versions = form['versions'].value.split(',').map(BigInt) const maxSubscribeId = form['max-subscribe-id'].value - await client.sendSetupMessage(role, versions, BigInt(maxSubscribeId)) + await client.sendSetupMessage(versions, BigInt(maxSubscribeId)) }) const sendAnnounceBtn = document.getElementById('sendAnnounceBtn') @@ -127,13 +125,13 @@ init().then(async () => { await client.sendAnnounceMessage(trackNamespace, authInfo) }) - const sendSubscribeNamespaceBtn = document.getElementById('sendSubscribeNamespaceBtn') - sendSubscribeNamespaceBtn.addEventListener('click', async () => { - console.log('send subscribe namespace btn clicked') + const sendSubscribeAnnouncesBtn = document.getElementById('sendSubscribeAnnouncesBtn') + sendSubscribeAnnouncesBtn.addEventListener('click', async () => { + console.log('send subscribe announces btn clicked') const trackNamespacePrefix = form['track-namespace-prefix'].value.split('/') const authInfo = form['auth-info'].value - await client.sendSubscribeNamespaceMessage(trackNamespacePrefix, authInfo) + await client.sendSubscribeAnnouncesMessage(trackNamespacePrefix, authInfo) }) const sendSubscribeBtn = document.getElementById('sendSubscribeBtn') @@ -149,9 +147,20 @@ init().then(async () => { const startGroup = form['start-group'].value const startObject = form['start-object'].value const endGroup = form['end-group'].value - const endObject = form['end-object'].value const authInfo = form['auth-info'].value + console.log( + subscribeId, + trackAlias, + trackNamespace, + trackName, + subscriberPriority, + groupOrder, + filterType, + startGroup, + startObject, + endGroup + ) await client.sendSubscribeMessage( BigInt(subscribeId), @@ -164,7 +173,6 @@ init().then(async () => { BigInt(startGroup), BigInt(startObject), BigInt(endGroup), - BigInt(endObject), authInfo ) }) @@ -179,7 +187,6 @@ init().then(async () => { const sendDatagramObjectBtn = document.getElementById('sendDatagramObjectBtn') sendDatagramObjectBtn.addEventListener('click', async () => { console.log('send datagram object btn clicked') - const subscribeId = form['object-subscribe-id'].value const trackAlias = form['object-track-alias'].value const publisherPriority = form['publisher-priority'].value const objectPayloadString = form['object-payload'].value @@ -187,136 +194,140 @@ init().then(async () => { // encode the text to the object array const objectPayloadArray = new TextEncoder().encode(objectPayloadString) - await client.sendDatagramObject( - BigInt(subscribeId), - BigInt(trackAlias), - mutableGroupId, - objectId++, - publisherPriority, - objectPayloadArray - ) + await client.sendDatagramObject(BigInt(trackAlias), groupId, objectId++, publisherPriority, objectPayloadArray) objectIdElement.textContent = objectId }) - const sendTrackObjectBtn = document.getElementById('sendTrackObjectBtn') - sendTrackObjectBtn.addEventListener('click', async () => { - console.log('send track stream object btn clicked') - const subscribeId = form['object-subscribe-id'].value + const sendDatagramObjectWithStatusBtn = document.getElementById('sendDatagramObjectWithStatusBtn') + sendDatagramObjectWithStatusBtn.addEventListener('click', async () => { + console.log('send datagram object with status btn clicked') + const trackAlias = form['object-track-alias'].value + const publisherPriority = form['publisher-priority'].value + const objectStatus = Array.from(form['object-status']).filter((elem) => elem.checked)[0].value + + await client.sendDatagramObjectStatus(BigInt(trackAlias), groupId, objectId++, publisherPriority, objectStatus) + objectIdElement.textContent = objectId + }) + + const sendSubgroupObjectBtn = document.getElementById('sendSubgroupObjectBtn') + sendSubgroupObjectBtn.addEventListener('click', async () => { + console.log('send subgroup stream object btn clicked') const trackAlias = form['object-track-alias'].value const publisherPriority = form['publisher-priority'].value const objectPayloadString = form['object-payload'].value // encode the text to the object array const objectPayloadArray = new TextEncoder().encode(objectPayloadString) + const key = `${groupId}:${subgroupId}` // send header if it is the first time - if (!trackHeaderSent) { - await client.sendTrackStreamHeaderMessage(BigInt(subscribeId), BigInt(trackAlias), publisherPriority) - trackHeaderSent = true + if (!subgroupHeaderSent.has(key)) { + await client.sendSubgroupStreamHeaderMessage(BigInt(trackAlias), groupId, subgroupId, publisherPriority) + subgroupHeaderSent.add(key) } - await client.sendTrackStreamObject(BigInt(subscribeId), mutableGroupId, objectId++, objectPayloadArray) + await client.sendSubgroupStreamObject( + BigInt(trackAlias), + groupId, + subgroupId, + objectId++, + undefined, + objectPayloadArray + ) objectIdElement.textContent = objectId }) - const sendSubgroupObjectBtn = document.getElementById('sendSubgroupObjectBtn') - sendSubgroupObjectBtn.addEventListener('click', async () => { - console.log('send subgroup stream object btn clicked') - const subscribeId = form['object-subscribe-id'].value + const sendSubgroupObjectWithStatusBtn = document.getElementById('sendSubgroupObjectWithStatusBtn') + sendSubgroupObjectWithStatusBtn.addEventListener('click', async () => { + console.log('send subgroup stream object with status btn clicked') const trackAlias = form['object-track-alias'].value const publisherPriority = form['publisher-priority'].value - const objectPayloadString = form['object-payload'].value + const objectStatus = Array.from(form['object-status']).filter((elem) => elem.checked)[0].value - // encode the text to the object array - const objectPayloadArray = new TextEncoder().encode(objectPayloadString) - const key = `${mutableGroupId}:${mutableSubgroupId}` + const objectPayloadArray = Uint8Array.from([]) + const key = `${groupId}:${subgroupId}` // send header if it is the first time if (!subgroupHeaderSent.has(key)) { - await client.sendSubgroupStreamHeaderMessage( - BigInt(subscribeId), - BigInt(trackAlias), - mutableGroupId, - mutableSubgroupId, - publisherPriority - ) + await client.sendSubgroupStreamHeaderMessage(BigInt(trackAlias), groupId, subgroupId, publisherPriority) subgroupHeaderSent.add(key) } await client.sendSubgroupStreamObject( - subscribeId, - mutableGroupId, - mutableSubgroupId, + BigInt(trackAlias), + groupId, + subgroupId, objectId++, + objectStatus, objectPayloadArray ) objectIdElement.textContent = objectId }) - const ascendMutableDatagramAndTrackGroupId = document.getElementById('ascendMutableDatagramAndTrackGroupIdBtn') - ascendMutableDatagramAndTrackGroupId.addEventListener('click', async () => { - mutableGroupId++ + const ascendDatagramGroupId = document.getElementById('ascendDatagramGroupIdBtn') + ascendDatagramGroupId.addEventListener('click', async () => { + groupId++ objectId = 0n - console.log('ascend mutableGroupId', mutableGroupId) - mutableDatagramAndTrackGroupIdElement.textContent = mutableGroupId + console.log('ascend groupId', groupId) + datagramGroupIdElement.textContent = groupId objectIdElement.textContent = objectId }) - const descendMutableDatagramAndTrackGroupId = document.getElementById('descendMutableDatagramAndTrackGroupIdBtn') - descendMutableDatagramAndTrackGroupId.addEventListener('click', async () => { - if (mutableGroupId === 0n) { + const descendDatagramGroupId = document.getElementById('descendDatagramGroupIdBtn') + descendDatagramGroupId.addEventListener('click', async () => { + if (groupId === 0n) { return } - mutableGroupId-- + groupId-- objectId = 0n - console.log('descend mutableGroupId', mutableGroupId) - mutableDatagramAndTrackGroupIdElement.textContent = mutableGroupId + console.log('descend groupId', groupId) + datagramGroupIdElement.textContent = groupId objectIdElement.textContent = objectId }) - const ascendMutableSubgroupGroupId = document.getElementById('ascendMutableSubgroupGroupIdBtn') - ascendMutableSubgroupGroupId.addEventListener('click', async () => { - mutableGroupId++ - mutableSubgroupId = 0n + const ascendSubgroupGroupId = document.getElementById('ascendSubgroupGroupIdBtn') + ascendSubgroupGroupId.addEventListener('click', async () => { + groupId++ + subgroupId = 0n objectId = 0n - console.log('ascend mutableGroupId', mutableGroupId) - mutableSubgroupGroupIdElement.textContent = mutableGroupId - mutableSubgroupIdElement.textContent = mutableSubgroupId + console.log('ascend groupId', groupId) + subgroupGroupIdElement.textContent = groupId + subgroupIdElement.textContent = subgroupId objectIdElement.textContent = objectId }) - const descendMutableSubgroupGroupId = document.getElementById('descendMutableSubgroupGroupIdBtn') - descendMutableSubgroupGroupId.addEventListener('click', async () => { - if (mutableGroupId === 0n) { + const descendSubgroupGroupId = document.getElementById('descendSubgroupGroupIdBtn') + descendSubgroupGroupId.addEventListener('click', async () => { + if (groupId === 0n) { return } - mutableGroupId-- - mutableSubgroupId = 0n + groupId-- + subgroupId = 0n objectId = 0n - console.log('descend mutableGroupId', mutableGroupId) - mutableSubgroupGroupIdElement.textContent = mutableGroupId - mutableSubgroupIdElement.textContent = mutableSubgroupId + console.log('descend groupId', groupId) + subgroupGroupIdElement.textContent = groupId + subgroupIdElement.textContent = subgroupId objectIdElement.textContent = objectId }) - const ascendMutableSubgroupId = document.getElementById('ascendMutableSubgroupIdBtn') - ascendMutableSubgroupId.addEventListener('click', async () => { - mutableSubgroupId++ + const ascendSubgroupId = document.getElementById('ascendSubgroupIdBtn') + ascendSubgroupId.addEventListener('click', async () => { + subgroupId++ objectId = 0n - console.log('ascend mutableSubgroupId', mutableSubgroupId) - mutableSubgroupIdElement.textContent = mutableSubgroupId + console.log('ascend subgroupId', subgroupId) + subgroupIdElement.textContent = subgroupId objectIdElement.textContent = objectId }) - const descendMutableSubroupId = document.getElementById('descendMutableSubgroupIdBtn') - descendMutableSubroupId.addEventListener('click', async () => { - if (mutableSubgroupId === 0n) { + const descendSubgroupId = document.getElementById('descendSubgroupIdBtn') + descendSubgroupId.addEventListener('click', async () => { + if (subgroupId === 0n) { return } - mutableSubgroupId-- + subgroupId-- objectId = 0n - console.log('descend mutableSubgroupId', mutableSubgroupId) - mutableSubgroupIdElement.textContent = mutableSubgroupId + console.log('descend subgroupId', subgroupId) + subgroupIdElement.textContent = subgroupId objectIdElement.textContent = objectId }) @@ -325,10 +336,11 @@ init().then(async () => { const forwardingPreference = document.querySelectorAll('input[name="forwarding-preference"]') const subgroupHeaderContents = document.getElementById('subgroupHeaderContents') - const datagramAndTrackObjectContents = document.getElementById('datagramAndTrackObjectContents') + const datagramObjectContents = document.getElementById('datagramObjectContents') const sendDatagramObject = document.getElementById('sendDatagramObject') - const sendTrackObject = document.getElementById('sendTrackObject') + const sendDatagramObjectWithStatus = document.getElementById('sendDatagramObjectWithStatus') const sendSubgroupObject = document.getElementById('sendSubgroupObject') + const sendSubgroupObjectWithStatus = document.getElementById('sendSubgroupObjectWithStatus') const headerField = document.getElementById('headerField') const objectField = document.getElementById('objectField') @@ -336,27 +348,21 @@ init().then(async () => { forwardingPreference.forEach((elem) => { elem.addEventListener('change', async () => { if (elem.value === 'datagram') { - datagramAndTrackObjectContents.style.display = 'block' + datagramObjectContents.style.display = 'block' subgroupHeaderContents.style.display = 'none' sendDatagramObject.style.display = 'block' - sendTrackObject.style.display = 'none' + sendDatagramObjectWithStatus.style.display = 'block' sendSubgroupObject.style.display = 'none' + sendSubgroupObjectWithStatus.style.display = 'none' headerField.style.display = 'none' objectField.style.display = 'none' - } else if (elem.value === 'track') { - datagramAndTrackObjectContents.style.display = 'block' - subgroupHeaderContents.style.display = 'none' - sendDatagramObject.style.display = 'none' - sendTrackObject.style.display = 'block' - sendSubgroupObject.style.display = 'none' - headerField.style.display = 'block' - objectField.style.display = 'block' } else if (elem.value === 'subgroup') { - datagramAndTrackObjectContents.style.display = 'none' + datagramObjectContents.style.display = 'none' subgroupHeaderContents.style.display = 'block' sendDatagramObject.style.display = 'none' - sendTrackObject.style.display = 'none' + sendDatagramObjectWithStatus.style.display = 'none' sendSubgroupObject.style.display = 'block' + sendSubgroupObjectWithStatus.style.display = 'block' headerField.style.display = 'block' objectField.style.display = 'block' } diff --git a/js/index.html b/js/index.html index 368a29a0..1ec4d1ff 100644 --- a/js/index.html +++ b/js/index.html @@ -1,161 +1,15 @@ - + - MOQT Message Test + Examples -
-
-
-

Connection

- - - - -
-
-
-

Forwarding Preference

- - - -
- *You must select at first -

Message

- -
-

CLIENT_SETUP

- -
- Max Subscribe ID: -
- Role: - - - -
- - -

ANNOUNCE

- -
- - -

SUBSCRIBE_NAMESPACE

- -
- - -

SUBSCRIBE

- -
- -
- -
- -
-
- -
- Group Order: - - - -
-
- Filter Type: - - - - -
- Start Group: Start Object: - -
- End Group: End Object: - -
- - -

UNSUBSCRIBE

- -
- - -

OBJECT

-
-

Header

-
- -
- -
- - -
-
-

Object

-
-
- Group ID: -

0

-
- - -
-
-
- Object ID: -

0

-
-
- -
- -
- -
- -
-
-
-

Received

-
-
-
+

Examples

+ - - - diff --git a/js/package-lock.json b/js/package-lock.json index 35d04070..4c021e5d 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -8,6 +8,9 @@ "name": "js", "version": "1.0.0", "license": "ISC", + "dependencies": { + "@types/dom-mediacapture-transform": "^0.1.11" + }, "devDependencies": { "prettier": "^3.2.5", "vite": "^4.4.10", @@ -22,6 +25,7 @@ "arm" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "android" @@ -38,6 +42,7 @@ "arm64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "android" @@ -54,6 +59,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "android" @@ -70,6 +76,7 @@ "arm64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "darwin" @@ -86,6 +93,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "darwin" @@ -102,6 +110,7 @@ "arm64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "freebsd" @@ -118,6 +127,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "freebsd" @@ -134,6 +144,7 @@ "arm" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -150,6 +161,7 @@ "arm64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -166,6 +178,7 @@ "ia32" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -182,6 +195,7 @@ "loong64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -198,6 +212,7 @@ "mips64el" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -214,6 +229,7 @@ "ppc64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -230,6 +246,7 @@ "riscv64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -246,6 +263,7 @@ "s390x" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -262,6 +280,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "linux" @@ -278,6 +297,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "netbsd" @@ -294,6 +314,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "openbsd" @@ -310,6 +331,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "sunos" @@ -326,6 +348,7 @@ "arm64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "win32" @@ -342,6 +365,7 @@ "ia32" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "win32" @@ -358,6 +382,7 @@ "x64" ], "dev": true, + "license": "MIT", "optional": true, "os": [ "win32" @@ -366,11 +391,27 @@ "node": ">=12" } }, + "node_modules/@types/dom-mediacapture-transform": { + "version": "0.1.11", + "resolved": "https://registry.npmjs.org/@types/dom-mediacapture-transform/-/dom-mediacapture-transform-0.1.11.tgz", + "integrity": "sha512-Y2p+nGf1bF2XMttBnsVPHUWzRRZzqUoJAKmiP10b5umnO6DDrWI0BrGDJy1pOHoOULVmGSfFNkQrAlC5dcj6nQ==", + "license": "MIT", + "dependencies": { + "@types/dom-webcodecs": "*" + } + }, + "node_modules/@types/dom-webcodecs": { + "version": "0.1.14", + "resolved": "https://registry.npmjs.org/@types/dom-webcodecs/-/dom-webcodecs-0.1.14.tgz", + "integrity": "sha512-ba9aF0qARLLQpLihONIRbj8VvAdUxO+5jIxlscVcDAQTcJmq5qVr781+ino5qbQUJUmO21cLP2eLeXYWzao5Vg==", + "license": "MIT" + }, "node_modules/ansi-styles": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", "dev": true, + "license": "MIT", "dependencies": { "color-convert": "^2.0.1" }, @@ -386,6 +427,7 @@ "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", "dev": true, + "license": "MIT", "dependencies": { "ansi-styles": "^4.1.0", "supports-color": "^7.1.0" @@ -402,6 +444,7 @@ "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", "dev": true, + "license": "MIT", "dependencies": { "color-name": "~1.1.4" }, @@ -413,7 +456,8 @@ "version": "1.1.4", "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/esbuild": { "version": "0.18.20", @@ -421,6 +465,7 @@ "integrity": "sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==", "dev": true, "hasInstallScript": true, + "license": "MIT", "bin": { "esbuild": "bin/esbuild" }, @@ -457,6 +502,7 @@ "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-10.1.0.tgz", "integrity": "sha512-oRXApq54ETRj4eMiFzGnHWGy+zo5raudjuxN0b8H7s/RU2oW0Wvsx9O0ACRN/kRq9E8Vu/ReskGB5o3ji+FzHQ==", "dev": true, + "license": "MIT", "dependencies": { "graceful-fs": "^4.2.0", "jsonfile": "^6.0.1", @@ -472,6 +518,7 @@ "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "dev": true, "hasInstallScript": true, + "license": "MIT", "optional": true, "os": [ "darwin" @@ -484,13 +531,15 @@ "version": "4.2.11", "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", - "dev": true + "dev": true, + "license": "ISC" }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=8" } @@ -500,6 +549,7 @@ "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.1.0.tgz", "integrity": "sha512-5dgndWOriYSm5cnYaJNhalLNDKOqFwyDB/rr1E9ZsGciGvKPs8R2xYGCacuf3z6K1YKDz182fd+fY3cn3pMqXQ==", "dev": true, + "license": "MIT", "dependencies": { "universalify": "^2.0.0" }, @@ -508,9 +558,9 @@ } }, "node_modules/nanoid": { - "version": "3.3.6", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.6.tgz", - "integrity": "sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==", + "version": "3.3.8", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", + "integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==", "dev": true, "funding": [ { @@ -518,6 +568,7 @@ "url": "https://github.com/sponsors/ai" } ], + "license": "MIT", "bin": { "nanoid": "bin/nanoid.cjs" }, @@ -529,18 +580,20 @@ "version": "1.5.0", "resolved": "https://registry.npmjs.org/narrowing/-/narrowing-1.5.0.tgz", "integrity": "sha512-DUu4XdKgkfAPTAL28k79pdnshDE2W5T24QAnidSPo2F/W1TX6CjNzmEeXQfE5O1lxQvC0GYI6ZRDsLcyzugEYA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", - "dev": true + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "dev": true, + "license": "ISC" }, "node_modules/postcss": { - "version": "8.4.31", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz", - "integrity": "sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==", + "version": "8.5.3", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", + "integrity": "sha512-dle9A3yYxlBSrt8Fu+IpjGT8SY8hN0mlaA6GY8t0P5PjIOZemULz/E2Bnm/2dcUOena75OTNkHI76uZBNUUq3A==", "dev": true, "funding": [ { @@ -556,20 +609,22 @@ "url": "https://github.com/sponsors/ai" } ], + "license": "MIT", "dependencies": { - "nanoid": "^3.3.6", - "picocolors": "^1.0.0", - "source-map-js": "^1.0.2" + "nanoid": "^3.3.8", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" }, "engines": { "node": "^10 || ^12 || >=14" } }, "node_modules/prettier": { - "version": "3.2.5", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.2.5.tgz", - "integrity": "sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==", + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.5.3.tgz", + "integrity": "sha512-QQtaxnoDJeAkDvDKWCLiwIXkTgRhwYDEQCghU9Z6q03iyek/rxRh/2lC3HB7P8sWT2xC/y5JDctPLBIGzHKbhw==", "dev": true, + "license": "MIT", "bin": { "prettier": "bin/prettier.cjs" }, @@ -581,10 +636,11 @@ } }, "node_modules/rollup": { - "version": "3.29.4", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-3.29.4.tgz", - "integrity": "sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==", + "version": "3.29.5", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-3.29.5.tgz", + "integrity": "sha512-GVsDdsbJzzy4S/v3dqWPJ7EfvZJfCHiDqe80IyrF59LYuP+e6U1LJoUqeuqRbwAWoMNoXivMNeNAOf5E22VA1w==", "dev": true, + "license": "MIT", "bin": { "rollup": "dist/bin/rollup" }, @@ -597,10 +653,11 @@ } }, "node_modules/source-map-js": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", - "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", "dev": true, + "license": "BSD-3-Clause", "engines": { "node": ">=0.10.0" } @@ -610,6 +667,7 @@ "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", "dev": true, + "license": "MIT", "dependencies": { "has-flag": "^4.0.0" }, @@ -618,19 +676,21 @@ } }, "node_modules/universalify": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.0.tgz", - "integrity": "sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.1.tgz", + "integrity": "sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 10.0.0" } }, "node_modules/vite": { - "version": "4.4.10", - "resolved": "https://registry.npmjs.org/vite/-/vite-4.4.10.tgz", - "integrity": "sha512-TzIjiqx9BEXF8yzYdF2NTf1kFFbjMjUSV0LFZ3HyHoI3SGSPLnnFUKiIQtL3gl2AjHvMrprOvQ3amzaHgQlAxw==", + "version": "4.5.9", + "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.9.tgz", + "integrity": "sha512-qK9W4xjgD3gXbC0NmdNFFnVFLMWSNiR3swj957yutwzzN16xF/E7nmtAyp1rT9hviDroQANjE4HK3H4WqWdFtw==", "dev": true, + "license": "MIT", "dependencies": { "esbuild": "^0.18.10", "postcss": "^8.4.27", @@ -686,6 +746,7 @@ "resolved": "https://registry.npmjs.org/vite-plugin-wasm-pack/-/vite-plugin-wasm-pack-0.1.12.tgz", "integrity": "sha512-WliYvQp9HXluir4OKGbngkcKxtYtifU11cqLurRRJGsl770Sjr1iIkp5RuvU3IC1poT4A57Z2/YgAKI2Skm7ZA==", "dev": true, + "license": "MIT", "dependencies": { "chalk": "^4.1.2", "fs-extra": "^10.0.0", diff --git a/js/package.json b/js/package.json index 28bfff5e..2af9c9a5 100644 --- a/js/package.json +++ b/js/package.json @@ -16,5 +16,8 @@ "prettier": "^3.2.5", "vite": "^4.4.10", "vite-plugin-wasm-pack": "^0.1.12" + }, + "dependencies": { + "@types/dom-mediacapture-transform": "^0.1.11" } } diff --git a/js/tsconfig.json b/js/tsconfig.json new file mode 100644 index 00000000..5a862bef --- /dev/null +++ b/js/tsconfig.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "target": "esnext", + "module": "esnext", + "moduleResolution": "node", + "strict": true, + "jsx": "preserve", + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "types": ["@types/dom-mediacapture-transform"] + }, + "include": ["examples/**/*.ts", "examples/**/*.tsx"], + "exclude": ["node_modules"] +} diff --git a/js/vite.config.js b/js/vite.config.js index 9a6b86ba..5ff904c2 100644 --- a/js/vite.config.js +++ b/js/vite.config.js @@ -1,11 +1,12 @@ import { defineConfig } from 'vite' -import wasmPack from 'vite-plugin-wasm-pack' export default defineConfig({ - // pass your local crate path to the plugin - // plugins: [wasmPack('../rust-wasm-vite')] - // server: { - // headers: { - // } - // } + build: { + rollupOptions: { + input: { + main: 'src/pages/index.html', + about: 'src/pages/about.html' + } + } + } }) diff --git a/moqt-client-sample/src/lib.rs b/moqt-client-sample/src/lib.rs index 3c3f7aa9..2fbc859d 100644 --- a/moqt-client-sample/src/lib.rs +++ b/moqt-client-sample/src/lib.rs @@ -13,20 +13,25 @@ use moqt_core::{ announce_error::AnnounceError, announce_ok::AnnounceOk, client_setup::ClientSetup, + group_order::GroupOrder, server_setup::ServerSetup, - setup_parameters::{MaxSubscribeID, Role, RoleCase, SetupParameter}, - subscribe::{FilterType, GroupOrder, Subscribe}, + setup_parameters::{MaxSubscribeID, SetupParameter}, + subscribe::{FilterType, Subscribe}, + subscribe_announces::SubscribeAnnounces, + subscribe_announces_error::SubscribeAnnouncesError, + subscribe_announces_ok::SubscribeAnnouncesOk, subscribe_error::{SubscribeError, SubscribeErrorCode}, - subscribe_namespace::SubscribeNamespace, - subscribe_namespace_error::SubscribeNamespaceError, - subscribe_namespace_ok::SubscribeNamespaceOk, subscribe_ok::SubscribeOk, unannounce::UnAnnounce, unsubscribe::Unsubscribe, - version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, + version_specific_parameters::{ + AuthorizationInfo, DeliveryTimeout, MaxCacheDuration, VersionSpecificParameter, + }, }, messages::{ - data_streams::{datagram, subgroup_stream, track_stream, DataStreams}, + data_streams::{ + datagram, datagram_status, object_status::ObjectStatus, subgroup_stream, DataStreams, + }, moqt_payload::MOQTPayload, }, models::subscriptions::{ @@ -147,11 +152,11 @@ impl MOQTClient { .borrow_mut() .set_subscribe_response_callback(callback); } - #[wasm_bindgen(js_name = onSubscribeNamespaceResponse)] - pub fn set_subscribe_namespace_response_callback(&mut self, callback: js_sys::Function) { + #[wasm_bindgen(js_name = onSubscribeAnnouncesResponse)] + pub fn set_subscribe_announces_response_callback(&mut self, callback: js_sys::Function) { self.callbacks .borrow_mut() - .set_subscribe_namespace_response_callback(callback); + .set_subscribe_announces_response_callback(callback); } #[wasm_bindgen(js_name = onDatagramObject)] @@ -161,18 +166,11 @@ impl MOQTClient { .set_datagram_object_callback(callback); } - #[wasm_bindgen(js_name = onTrackStreamHeader)] - pub fn set_track_stream_header_callback(&mut self, callback: js_sys::Function) { + #[wasm_bindgen(js_name = onDatagramObjectStatus)] + pub fn set_datagram_object_status_callback(&mut self, callback: js_sys::Function) { self.callbacks .borrow_mut() - .set_track_stream_header_callback(callback); - } - - #[wasm_bindgen(js_name = onTrackStreamObject)] - pub fn set_track_stream_object_callback(&mut self, callback: js_sys::Function) { - self.callbacks - .borrow_mut() - .set_track_stream_object_callback(callback); + .set_datagram_object_status_callback(callback); } #[wasm_bindgen(js_name = onSubgroupStreamHeader)] @@ -183,33 +181,28 @@ impl MOQTClient { } #[wasm_bindgen(js_name = onSubgroupStreamObject)] - pub fn set_subgroup_stream_object_callback(&mut self, callback: js_sys::Function) { + pub fn set_subgroup_stream_object_callback( + &mut self, + track_alias: u64, + callback: js_sys::Function, + ) { self.callbacks .borrow_mut() - .set_subgroup_stream_object_callback(callback); + .set_subgroup_stream_object_callback(track_alias, callback); } #[wasm_bindgen(js_name = sendSetupMessage)] pub async fn send_setup_message( &mut self, - role_value: u8, versions: Vec, max_subscribe_id: u64, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { - let role = RoleCase::try_from(role_value).unwrap(); + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { let versions = versions.iter().map(|v| *v as u32).collect::>(); - let mut setup_parameters: Vec = - vec![SetupParameter::Role(Role::new(role))]; - - match role { - RoleCase::Publisher | RoleCase::PubSub => { - setup_parameters.push(SetupParameter::MaxSubscribeID(MaxSubscribeID::new( - max_subscribe_id, - ))); - } - _ => {} - } + let setup_parameters = vec![SetupParameter::MaxSubscribeID(MaxSubscribeID::new( + max_subscribe_id, + ))]; let client_setup_message = ClientSetup::new(versions, setup_parameters); let mut client_setup_message_buf = BytesMut::new(); @@ -231,26 +224,13 @@ impl MOQTClient { // Setup nodes along with the role Ok(ok) => { log(std::format!("sent: client_setup: {:#x?}", client_setup_message).as_str()); - match role { - RoleCase::Publisher => { - self.subscription_node - .borrow_mut() - .setup_as_publisher(max_subscribe_id); - } - RoleCase::Subscriber => { - self.subscription_node - .borrow_mut() - .setup_as_subscriber(max_subscribe_id); - } - RoleCase::PubSub => { - self.subscription_node - .borrow_mut() - .setup_as_publisher(max_subscribe_id); - self.subscription_node - .borrow_mut() - .setup_as_subscriber(max_subscribe_id); - } - } + self.subscription_node + .borrow_mut() + .setup_as_publisher(max_subscribe_id); + self.subscription_node + .borrow_mut() + .setup_as_subscriber(max_subscribe_id); + Ok(ok) } Err(e) => Err(e), @@ -267,7 +247,8 @@ impl MOQTClient { track_namespace: js_sys::Array, auth_info: String, // param[0] ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { let auth_info_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(auth_info)); @@ -315,7 +296,8 @@ impl MOQTClient { &self, track_namespace: js_sys::Array, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { let length = track_namespace.length(); let mut track_namespace_vec: Vec = Vec::with_capacity(length as usize); for i in 0..length { @@ -359,7 +341,8 @@ impl MOQTClient { &self, track_namespace: js_sys::Array, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { // TODO: construct UnAnnounce let length = track_namespace.length(); let mut track_namespace_vec: Vec = Vec::with_capacity(length as usize); @@ -401,6 +384,7 @@ impl MOQTClient { // tmp impl #[wasm_bindgen(js_name = sendSubscribeMessage)] + #[allow(clippy::too_many_arguments)] pub async fn send_subscribe_message( &self, subscribe_id: u64, @@ -413,10 +397,10 @@ impl MOQTClient { start_group: u64, start_object: u64, end_group: u64, - end_object: u64, auth_info: String, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { // This is equal to `Now example` let auth_info = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(auth_info)); @@ -438,14 +422,18 @@ impl MOQTClient { (Some(start_group), Some(start_object)) } }; - let (end_group, end_object) = match filter_type { + let end_group = match filter_type { FilterType::LatestObject | FilterType::LatestGroup | FilterType::AbsoluteStart => { - (None, None) + None } - FilterType::AbsoluteRange => (Some(end_group), Some(end_object)), + FilterType::AbsoluteRange => Some(end_group), }; - let version_specific_parameters = vec![auth_info]; + let max_cache_duration = + VersionSpecificParameter::MaxCacheDuration(MaxCacheDuration::new(1000000)); + let delivery_timeout = + VersionSpecificParameter::DeliveryTimeout(DeliveryTimeout::new(100000)); + let version_specific_parameters = vec![auth_info, max_cache_duration, delivery_timeout]; let subscribe_message = Subscribe::new( subscribe_id, track_alias, @@ -457,7 +445,6 @@ impl MOQTClient { start_group, start_object, end_group, - end_object, version_specific_parameters, ) .unwrap(); @@ -493,7 +480,6 @@ impl MOQTClient { start_group, start_object, end_group, - end_object, ); Ok(ok) @@ -513,7 +499,8 @@ impl MOQTClient { auth_info: String, fowarding_preference: String, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { let auth_info = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(auth_info)); let subscription = self @@ -577,16 +564,14 @@ impl MOQTClient { *self.datagram_writer.borrow_mut() = Some(datagram_writer); } "track" => { - let send_uni_stream = WritableStream::from( - JsFuture::from( - self.transport - .borrow() - .as_ref() - .unwrap() - .create_unidirectional_stream(), - ) - .await?, - ); + let send_uni_stream = self + .transport + .borrow() + .as_ref() + .unwrap() + .create_unidirectional_stream(); + let send_uni_stream = + WritableStream::from(JsFuture::from(send_uni_stream).await?); let send_uni_stream_writer = send_uni_stream.get_writer()?; let writer_key = (subscribe_id, None); @@ -609,13 +594,14 @@ impl MOQTClient { } } - #[wasm_bindgen(js_name = sendSubscribeNamespaceMessage)] - pub async fn send_subscribe_namespace_message( + #[wasm_bindgen(js_name = sendSubscribeAnnouncesMessage)] + pub async fn send_subscribe_announces_message( &self, track_namespace_prefix: js_sys::Array, auth_info: String, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { let auth_info = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(auth_info)); let length = track_namespace_prefix.length(); @@ -629,21 +615,21 @@ impl MOQTClient { } let version_specific_parameters = vec![auth_info]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix_vec, version_specific_parameters); - let mut subscribe_namespace_message_buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut subscribe_namespace_message_buf); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix_vec, version_specific_parameters); + let mut subscribe_announces_message_buf = BytesMut::new(); + subscribe_announces_message.packetize(&mut subscribe_announces_message_buf); let mut buf = Vec::new(); // Message Type buf.extend(write_variable_integer( - u8::from(ControlMessageType::SubscribeNamespace) as u64, + u8::from(ControlMessageType::SubscribeAnnounces) as u64, )); // Message Payload and Payload Length buf.extend(write_variable_integer( - subscribe_namespace_message_buf.len() as u64, + subscribe_announces_message_buf.len() as u64, )); - buf.extend(subscribe_namespace_message_buf); + buf.extend(subscribe_announces_message_buf); let buffer = js_sys::Uint8Array::new_with_length(buf.len() as u32); buffer.copy_from(&buf); @@ -651,8 +637,8 @@ impl MOQTClient { match JsFuture::from(writer.write_with_chunk(&buffer)).await { Ok(ok) => { log(std::format!( - "sent: subscribe_namespace: {:#x?}", - subscribe_namespace_message + "sent: subscribe_announces: {:#x?}", + subscribe_announces_message ) .as_str()); Ok(ok) @@ -671,7 +657,8 @@ impl MOQTClient { error_code: u64, reason_phrase: String, ) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { // Find unused subscribe_id and track_alias automatically let valid_track_alias = self .subscription_node @@ -718,7 +705,8 @@ impl MOQTClient { #[wasm_bindgen(js_name = sendUnsubscribeMessage)] pub async fn send_unsubscribe_message(&self, subscribe_id: u64) -> Result { - if let Some(writer) = &*self.control_stream_writer.borrow() { + let writer = self.control_stream_writer.borrow().clone(); + if let Some(writer) = writer { let unsubscribe_message = Unsubscribe::new(subscribe_id); let mut unsubscribe_message_buf = BytesMut::new(); unsubscribe_message.packetize(&mut unsubscribe_message_buf); @@ -750,26 +738,26 @@ impl MOQTClient { #[wasm_bindgen(js_name = sendDatagramObject)] pub async fn send_datagram_object( &self, - subscribe_id: u64, track_alias: u64, group_id: u64, object_id: u64, publisher_priority: u8, object_payload: Vec, ) -> Result { - if let Some(writer) = &*self.datagram_writer.borrow() { + let writer = self.datagram_writer.borrow().clone(); + if let Some(writer) = writer { + let extension_headers = vec![]; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - None, + extension_headers, object_payload, ) .unwrap(); let mut datagram_object_buf = BytesMut::new(); - let _ = datagram_object.packetize(&mut datagram_object_buf); + datagram_object.packetize(&mut datagram_object_buf); let mut buf = Vec::new(); // Message Type @@ -791,69 +779,40 @@ impl MOQTClient { } } } else { - return Err(JsValue::from_str("datagram_writer is None")); + Err(JsValue::from_str("datagram_writer is None")) } } - #[wasm_bindgen(js_name = sendTrackStreamHeaderMessage)] - pub async fn send_track_stream_header_message( + #[wasm_bindgen(js_name = sendDatagramObjectStatus)] + pub async fn send_datagram_object_status( &self, - subscribe_id: u64, track_alias: u64, + group_id: u64, + object_id: u64, publisher_priority: u8, + object_status: u8, ) -> Result { - let stream_writers = self.stream_writers.borrow(); - let writer_key = (subscribe_id, None); - if let Some(writer) = stream_writers.get(&writer_key) { - let track_stream_header_message = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - let mut track_stream_header_message_buf = BytesMut::new(); - let _ = track_stream_header_message.packetize(&mut track_stream_header_message_buf); + let writer = self.datagram_writer.borrow().clone(); + if let Some(writer) = writer { + let extension_headers = vec![]; + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + ObjectStatus::try_from(object_status).unwrap(), + ) + .unwrap(); + let mut datagram_object_buf = BytesMut::new(); + datagram_object.packetize(&mut datagram_object_buf); let mut buf = Vec::new(); // Message Type buf.extend(write_variable_integer( - u8::from(DataStreamType::StreamHeaderTrack) as u64, + u8::from(DataStreamType::ObjectDatagramStatus) as u64, )); - buf.extend(track_stream_header_message_buf); - - let buffer = js_sys::Uint8Array::new_with_length(buf.len() as u32); - buffer.copy_from(&buf); - match JsFuture::from(writer.write_with_chunk(&buffer)).await { - Ok(ok) => { - log(std::format!( - "sent: track_stream_header: {:#x?}", - track_stream_header_message - ) - .as_str()); - Ok(ok) - } - Err(e) => Err(e), - } - } else { - return Err(JsValue::from_str("stream_writer is None")); - } - } - - #[wasm_bindgen(js_name = sendTrackStreamObject)] - pub async fn send_track_stream_object( - &self, - subscribe_id: u64, - group_id: u64, - object_id: u64, - object_payload: Vec, - ) -> Result { - let stream_writers = self.stream_writers.borrow(); - let writer_key = (subscribe_id, None); - if let Some(writer) = stream_writers.get(&writer_key) { - let track_stream_object = - track_stream::Object::new(group_id, object_id, None, object_payload).unwrap(); - let mut track_stream_object_buf = BytesMut::new(); - let _ = track_stream_object.packetize(&mut track_stream_object_buf); - - let mut buf = Vec::new(); - // Message Payload and Payload Length - buf.extend(track_stream_object_buf); + buf.extend(datagram_object_buf); let buffer = js_sys::Uint8Array::new_with_length(buf.len() as u32); buffer.copy_from(&buf); @@ -868,52 +827,64 @@ impl MOQTClient { } } } else { - return Err(JsValue::from_str("stream_writer is None")); + Err(JsValue::from_str("datagram_writer is None")) } } #[wasm_bindgen(js_name = sendSubgroupStreamHeaderMessage)] pub async fn send_subgroup_stream_header_message( &self, - subscribe_id: u64, track_alias: u64, group_id: u64, subgroup_id: u64, publisher_priority: u8, ) -> Result { - let mut stream_writers = self.stream_writers.borrow_mut(); + let subscribe_id = self + .subscription_node + .borrow() + .get_publishing_subscribe_id_by_track_alias(track_alias) + .unwrap(); let writer_key = (subscribe_id, Some((group_id, subgroup_id))); - if stream_writers.get(&writer_key).is_none() { - let send_uni_stream = WritableStream::from( - JsFuture::from( - self.transport - .borrow() - .as_ref() - .unwrap() - .create_unidirectional_stream(), - ) - .await?, - ); + + let writer = { + let stream_writers = self.stream_writers.borrow_mut(); + + if !stream_writers.contains_key(&writer_key) { + // 新しいwriterを作成 + let _ = { + let transport = self.transport.borrow(); + transport.as_ref().unwrap().create_unidirectional_stream() + }; + }; + stream_writers.get(&writer_key).cloned() + }; + let writer = if let Some(writer) = writer { + writer + } else { + let uni_stream_future = { + let transport = self.transport.borrow(); + transport.as_ref().unwrap().create_unidirectional_stream() + }; + let send_uni_stream = WritableStream::from(JsFuture::from(uni_stream_future).await?); let send_uni_stream_writer = send_uni_stream.get_writer()?; - stream_writers.insert(writer_key, send_uni_stream_writer); - } - let writer = stream_writers.get(&writer_key).unwrap(); - let subgroup_stream_header_message = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, - ) - .unwrap(); + // 作成したwriterを保存 + self.stream_writers + .borrow_mut() + .insert(writer_key, send_uni_stream_writer.clone()); + send_uni_stream_writer + }; + + let subgroup_stream_header_message = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); let mut subgroup_stream_header_message_buf = BytesMut::new(); - let _ = subgroup_stream_header_message.packetize(&mut subgroup_stream_header_message_buf); + subgroup_stream_header_message.packetize(&mut subgroup_stream_header_message_buf); let mut buf = Vec::new(); // Message Type buf.extend(write_variable_integer( - u8::from(DataStreamType::StreamHeaderSubgroup) as u64, + u8::from(DataStreamType::SubgroupHeader) as u64, )); buf.extend(subgroup_stream_header_message_buf); @@ -925,19 +896,34 @@ impl MOQTClient { #[wasm_bindgen(js_name = sendSubgroupStreamObject)] pub async fn send_subgroup_stream_object( &self, - subscribe_id: u64, + track_alias: u64, group_id: u64, subgroup_id: u64, object_id: u64, + object_status: Option, object_payload: Vec, ) -> Result { - let stream_writers = self.stream_writers.borrow(); + let subscribe_id = self + .subscription_node + .borrow() + .get_publishing_subscribe_id_by_track_alias(track_alias) + .unwrap(); let writer_key = (subscribe_id, Some((group_id, subgroup_id))); - if let Some(writer) = stream_writers.get(&writer_key) { - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, None, object_payload).unwrap(); + let writer = { + let stream_writers = self.stream_writers.borrow(); + stream_writers.get(&writer_key).cloned() + }; + if let Some(writer) = writer { + let extension_headers = vec![]; + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status.map(|status| ObjectStatus::try_from(status).unwrap()), + object_payload, + ) + .unwrap(); let mut subgroup_stream_object_buf = BytesMut::new(); - let _ = subgroup_stream_object.packetize(&mut subgroup_stream_object_buf); + subgroup_stream_object.packetize(&mut subgroup_stream_object_buf); let mut buf = Vec::new(); // Message Payload and Payload Length @@ -947,7 +933,15 @@ impl MOQTClient { buffer.copy_from(&buf); match JsFuture::from(writer.write_with_chunk(&buffer)).await { Ok(ok) => { - log(std::format!("sent: object id: {:#?}", object_id).as_str()); + log(std::format!( + "sent: trackAlias: {:#?} object . group_id: {:#?} subgroup_id: {:#?} object_id: {:#?} object_status: {:#?}", + track_alias, + group_id, + subgroup_id, + object_id, + object_status, + ) + .as_str()); Ok(ok) } Err(e) => { @@ -956,7 +950,7 @@ impl MOQTClient { } } } else { - return Err(JsValue::from_str("stream_writer is None")); + Err(JsValue::from_str("stream_writer is None")) } } @@ -1010,8 +1004,7 @@ impl MOQTClient { // For receiving object messages as streams let incoming_uni_stream = transport.incoming_unidirectional_streams(); - let incoming_uni_stream_reader = - ReadableStreamDefaultReader::new(&&incoming_uni_stream.into())?; + let incoming_uni_stream_reader = ReadableStreamDefaultReader::new(&incoming_uni_stream)?; let callbacks = self.callbacks.clone(); *self.stream_writers.borrow_mut() = HashMap::new(); @@ -1109,15 +1102,15 @@ async fn bi_directional_stream_read_thread( async fn control_message_handler( callbacks: Rc>, subscription_node: Rc>, - mut buf: &mut BytesMut, + buf: &mut BytesMut, ) -> Result<()> { - let message_type_value = read_variable_integer_from_buffer(&mut buf); + let message_type_value = read_variable_integer_from_buffer(buf); // TODO: Check stream type match message_type_value { Ok(v) => { let message_type = ControlMessageType::try_from(v as u8)?; - let payload_length = read_variable_integer_from_buffer(&mut buf)?; + let payload_length = read_variable_integer_from_buffer(buf)?; let mut payload_buf = buf.split_to(payload_length as usize); log(std::format!("message_type_value: {:#?}", message_type).as_str()); @@ -1125,16 +1118,11 @@ async fn control_message_handler( match message_type { ControlMessageType::ServerSetup => { let server_setup_message = ServerSetup::depacketize(&mut payload_buf)?; - log( std::format!("recv: server_setup_message: {:#x?}", server_setup_message) .as_str(), ); - if let Some(callback) = callbacks.borrow().setup_callback() { - callback - .call1(&JsValue::null(), &JsValue::from("called2")) - .unwrap(); let v = serde_wasm_bindgen::to_value(&server_setup_message).unwrap(); callback.call1(&JsValue::null(), &(v)).unwrap(); } @@ -1155,7 +1143,7 @@ async fn control_message_handler( .as_str(), ); - let _ = subscription_node + subscription_node .borrow_mut() .set_namespace(announce_ok_message.track_namespace().clone()); @@ -1217,7 +1205,7 @@ async fn control_message_handler( .as_str(), ); - let _ = subscription_node + subscription_node .borrow_mut() .activate_as_subscriber(subscribe_ok_message.subscribe_id()); @@ -1239,42 +1227,42 @@ async fn control_message_handler( callback.call1(&JsValue::null(), &(v)).unwrap(); } } - ControlMessageType::SubscribeNamespaceOk => { - let subscribe_namespace_ok_message = - SubscribeNamespaceOk::depacketize(&mut payload_buf)?; + ControlMessageType::SubscribeAnnouncesOk => { + let subscribe_announces_ok_message = + SubscribeAnnouncesOk::depacketize(&mut payload_buf)?; log(std::format!( - "recv: subscribe_namespace_ok_message: {:#x?}", - subscribe_namespace_ok_message + "recv: subscribe_announces_ok_message: {:#x?}", + subscribe_announces_ok_message ) .as_str()); - let _ = subscription_node.borrow_mut().set_namespace_prefix( - subscribe_namespace_ok_message + subscription_node.borrow_mut().set_namespace_prefix( + subscribe_announces_ok_message .track_namespace_prefix() .clone(), ); if let Some(callback) = - callbacks.borrow().subscribe_namespace_response_callback() + callbacks.borrow().subscribe_announces_response_callback() { let v = - serde_wasm_bindgen::to_value(&subscribe_namespace_ok_message).unwrap(); + serde_wasm_bindgen::to_value(&subscribe_announces_ok_message).unwrap(); callback.call1(&JsValue::null(), &(v)).unwrap(); } } - ControlMessageType::SubscribeNamespaceError => { - let subscribe_namespace_error_message = - SubscribeNamespaceError::depacketize(&mut payload_buf)?; + ControlMessageType::SubscribeAnnouncesError => { + let subscribe_announces_error_message = + SubscribeAnnouncesError::depacketize(&mut payload_buf)?; log(std::format!( - "recv: subscribe_namespace_error_message: {:#x?}", - subscribe_namespace_error_message + "recv: subscribe_announces_error_message: {:#x?}", + subscribe_announces_error_message ) .as_str()); if let Some(callback) = - callbacks.borrow().subscribe_namespace_response_callback() + callbacks.borrow().subscribe_announces_response_callback() { - let v = serde_wasm_bindgen::to_value(&subscribe_namespace_error_message) + let v = serde_wasm_bindgen::to_value(&subscribe_announces_error_message) .unwrap(); callback.call1(&JsValue::null(), &(v)).unwrap(); } @@ -1321,7 +1309,7 @@ async fn datagram_read_thread( buf.put_u8(i); } - while buf.len() > 0 { + while !buf.is_empty() { if let Err(e) = datagram_handler(callbacks.clone(), &mut buf).await { log(std::format!("error: {:#?}", e).as_str()); break; @@ -1336,69 +1324,77 @@ async fn uni_directional_stream_read_thread( callbacks: Rc>, reader: &ReadableStreamDefaultReader, ) -> Result<(), JsValue> { - use moqt_core::data_stream_type::DataStreamType; - log("uni_directional_stream_read_thread"); - let mut header_read = false; + let mut subgroup_stream_header: Option = None; let mut data_stream_type = DataStreamType::ObjectDatagram; let mut buf = BytesMut::new(); - - loop { - let ret = reader.read(); - let ret = JsFuture::from(ret).await?; - - let ret_value = js_sys::Reflect::get(&ret, &JsValue::from_str("value"))?; - let ret_done = js_sys::Reflect::get(&ret, &JsValue::from_str("done"))?; - let ret_done = js_sys::Boolean::from(ret_done).value_of(); - - if ret_done { + let mut is_end_of_stream = false; + + while !is_end_of_stream { + let ret = JsFuture::from(reader.read()).await?; + let is_done = + js_sys::Boolean::from(js_sys::Reflect::get(&ret, &JsValue::from_str("done"))?) + .value_of(); + if is_done { break; } + let value = + js_sys::Uint8Array::from(js_sys::Reflect::get(&ret, &JsValue::from_str("value"))?) + .to_vec(); - let ret_value = js_sys::Uint8Array::from(ret_value).to_vec(); - - for i in ret_value { + for i in value { buf.put_u8(i); } - while buf.len() > 0 { - if !header_read { - data_stream_type = match object_header_handler(callbacks.clone(), &mut buf).await { - Ok(v) => v, - Err(_e) => { - break; - } - }; - - header_read = true; - } else { - match data_stream_type { - DataStreamType::ObjectDatagram => { - let msg = "format error".to_string(); - log(std::format!("{}", msg).as_str()); - return Err(js_sys::Error::new(&msg).into()); - } - DataStreamType::StreamHeaderTrack => { - if let Err(e) = - track_stream_object_handler(callbacks.clone(), &mut buf).await - { - log(std::format!("error: {:#?}", e).as_str()); + while !buf.is_empty() { + if subgroup_stream_header.is_none() { + let (_data_stream_type, _subgroup_stream_header) = + match object_header_handler(callbacks.clone(), &mut buf).await { + Ok(v) => v, + Err(_e) => { break; } - } - DataStreamType::StreamHeaderSubgroup => { - if let Err(e) = - subgroup_stream_object_handler(callbacks.clone(), &mut buf).await - { - log(std::format!("error: {:#?}", e).as_str()); + }; + data_stream_type = _data_stream_type; + subgroup_stream_header = _subgroup_stream_header; + continue; + } + + match data_stream_type { + DataStreamType::ObjectDatagram | DataStreamType::ObjectDatagramStatus => { + let msg = "format error".to_string(); + log(std::format!("{:#?}", msg).as_str()); + return Err(js_sys::Error::new(&msg).into()); + } + DataStreamType::SubgroupHeader => { + match subgroup_stream_object_handler( + callbacks.clone(), + subgroup_stream_header.clone().unwrap(), + &mut buf, + ) + .await + { + Ok(object) => { + if object.object_status() == Some(ObjectStatus::EndOfGroup) { + is_end_of_stream = true; + break; + } + } + Err(_e) => { + // log(std::format!("error: {:#?}", e).as_str()); break; } } } + DataStreamType::FetchHeader => { + unimplemented!(); + } } } } + JsFuture::from(reader.cancel()).await?; + log("End of unidirectional stream"); Ok(()) } @@ -1407,35 +1403,18 @@ async fn uni_directional_stream_read_thread( async fn object_header_handler( callbacks: Rc>, buf: &mut BytesMut, -) -> Result { +) -> Result<(DataStreamType, Option)> { let mut read_cur = Cursor::new(&buf[..]); let header_type_value = read_variable_integer(&mut read_cur); - let data_stream_type = match header_type_value { + let (data_stream_type, subgroup_stream_header) = match header_type_value { Ok(v) => { let data_stream_type = DataStreamType::try_from(v as u8)?; log(std::format!("data_stream_type_value: {:#x?}", data_stream_type).as_str()); - match data_stream_type { - DataStreamType::StreamHeaderTrack => { - let track_stream_header = track_stream::Header::depacketize(&mut read_cur)?; - buf.advance(read_cur.position() as usize); - - log( - std::format!("recv: track_stream_header: {:#x?}", track_stream_header) - .as_str(), - ); - - if let Some(callback) = callbacks.borrow().track_stream_header_callback() { - callback - .call1(&JsValue::null(), &JsValue::from("called2")) - .unwrap(); - let v = serde_wasm_bindgen::to_value(&track_stream_header).unwrap(); - callback.call1(&JsValue::null(), &(v)).unwrap(); - } - } - DataStreamType::StreamHeaderSubgroup => { + let subgroup_stream_header = match data_stream_type { + DataStreamType::SubgroupHeader => { let subgroup_stream_header = subgroup_stream::Header::depacketize(&mut read_cur)?; buf.advance(read_cur.position() as usize); @@ -1446,20 +1425,19 @@ async fn object_header_handler( ); if let Some(callback) = callbacks.borrow().subgroup_stream_header_callback() { - callback - .call1(&JsValue::null(), &JsValue::from("called2")) - .unwrap(); let v = serde_wasm_bindgen::to_value(&subgroup_stream_header).unwrap(); callback.call1(&JsValue::null(), &(v)).unwrap(); } + Some(subgroup_stream_header) } _ => { // TODO: impl rest of message type log(std::format!("data_stream_type: {:#?}", data_stream_type).as_str()); + None } }; - data_stream_type + (data_stream_type, subgroup_stream_header) } Err(e) => { log("data_stream_type_value is None"); @@ -1467,11 +1445,13 @@ async fn object_header_handler( } }; - Ok(data_stream_type) + Ok((data_stream_type, subgroup_stream_header)) } #[cfg(feature = "web_sys_unstable_apis")] async fn datagram_handler(callbacks: Rc>, buf: &mut BytesMut) -> Result<()> { + use moqt_core::messages::data_streams::datagram_status; + let mut read_cur = Cursor::new(&buf[..]); let header_type_value = read_variable_integer(&mut read_cur); @@ -1481,31 +1461,49 @@ async fn datagram_handler(callbacks: Rc>, buf: &mut Bytes log(std::format!("data_stream_type_value: {:#x?}", data_stream_type).as_str()); - if data_stream_type == DataStreamType::ObjectDatagram { - let datagram_object = match datagram::Object::depacketize(&mut read_cur) { - Ok(v) => { - log(std::format!("object_id: {:#?}", v.object_id()).as_str()); - buf.advance(read_cur.position() as usize); - v - } - Err(e) => { - read_cur.set_position(0); - log(std::format!("retry because: {:#?}", e).as_str()); - return Err(e); + match data_stream_type { + DataStreamType::ObjectDatagram => { + let datagram_object = match datagram::Object::depacketize(&mut read_cur) { + Ok(v) => { + buf.advance(read_cur.position() as usize); + v + } + Err(e) => { + read_cur.set_position(0); + log(std::format!("retry because: {:#?}", e).as_str()); + return Err(e); + } + }; + + if let Some(callback) = callbacks.borrow().datagram_object_callback() { + let v = serde_wasm_bindgen::to_value(&datagram_object).unwrap(); + callback.call1(&JsValue::null(), &(v)).unwrap(); } - }; + } + DataStreamType::ObjectDatagramStatus => { + let datagram_object = match datagram_status::Object::depacketize(&mut read_cur) + { + Ok(v) => { + buf.advance(read_cur.position() as usize); + v + } + Err(e) => { + read_cur.set_position(0); + log(std::format!("retry because: {:#?}", e).as_str()); + return Err(e); + } + }; - if let Some(callback) = callbacks.borrow().datagram_object_callback() { - callback - .call1(&JsValue::null(), &JsValue::from("called2")) - .unwrap(); - let v = serde_wasm_bindgen::to_value(&datagram_object).unwrap(); - callback.call1(&JsValue::null(), &(v)).unwrap(); + if let Some(callback) = callbacks.borrow().datagram_object_status_callback() { + let v = serde_wasm_bindgen::to_value(&datagram_object).unwrap(); + callback.call1(&JsValue::null(), &(v)).unwrap(); + } + } + _ => { + let msg = "format error".to_string(); + log(std::format!("msg: {}", msg).as_str()); + return Err(anyhow::anyhow!(msg)); } - } else { - let msg = "format error".to_string(); - log(std::format!("{}", msg).as_str()); - return Err(anyhow::anyhow!(msg)); } } Err(e) => { @@ -1517,64 +1515,34 @@ async fn datagram_handler(callbacks: Rc>, buf: &mut Bytes Ok(()) } -#[cfg(feature = "web_sys_unstable_apis")] -async fn track_stream_object_handler( - callbacks: Rc>, - buf: &mut BytesMut, -) -> Result<()> { - let mut read_cur = Cursor::new(&buf[..]); - let track_stream_object = match track_stream::Object::depacketize(&mut read_cur) { - Ok(v) => { - log(std::format!("object_id: {:#?}", v.object_id()).as_str()); - buf.advance(read_cur.position() as usize); - v - } - Err(e) => { - read_cur.set_position(0); - log(std::format!("retry because: {:#?}", e).as_str()); - return Err(e); - } - }; - - if let Some(callback) = callbacks.borrow().track_stream_object_callback() { - callback - .call1(&JsValue::null(), &JsValue::from("called2")) - .unwrap(); - let v = serde_wasm_bindgen::to_value(&track_stream_object).unwrap(); - callback.call1(&JsValue::null(), &(v)).unwrap(); - } - - Ok(()) -} - #[cfg(feature = "web_sys_unstable_apis")] async fn subgroup_stream_object_handler( callbacks: Rc>, + subgroup_stream_header: subgroup_stream::Header, buf: &mut BytesMut, -) -> Result<()> { +) -> Result { let mut read_cur = Cursor::new(&buf[..]); let subgroup_stream_object = match subgroup_stream::Object::depacketize(&mut read_cur) { Ok(v) => { - log(std::format!("object_id: {:#?}", v.object_id()).as_str()); buf.advance(read_cur.position() as usize); v } Err(e) => { read_cur.set_position(0); - log(std::format!("retry because: {:#?}", e).as_str()); + // log(std::format!("retry because: {:#?}", e).as_str()); return Err(e); } }; - if let Some(callback) = callbacks.borrow().subgroup_stream_object_callback() { - callback - .call1(&JsValue::null(), &JsValue::from("called2")) - .unwrap(); + if let Some(callback) = callbacks + .borrow() + .get_subgroup_stream_object_callback(subgroup_stream_header.track_alias()) + { let v = serde_wasm_bindgen::to_value(&subgroup_stream_object).unwrap(); callback.call1(&JsValue::null(), &(v)).unwrap(); } - Ok(()) + Ok(subgroup_stream_object) } #[cfg(feature = "web_sys_unstable_apis")] @@ -1620,6 +1588,7 @@ impl SubscriptionNode { } } + #[allow(clippy::too_many_arguments)] fn set_subscribing_subscription( &mut self, subscribe_id: u64, @@ -1632,7 +1601,6 @@ impl SubscriptionNode { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) { if let Some(consumer) = &mut self.consumer { let _ = consumer.set_subscription( @@ -1646,11 +1614,11 @@ impl SubscriptionNode { start_group, start_object, end_group, - end_object, ); } } + #[allow(clippy::too_many_arguments)] fn set_publishing_subscription( &mut self, subscribe_id: u64, @@ -1663,7 +1631,6 @@ impl SubscriptionNode { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) { if let Some(producer) = &mut self.producer { let _ = producer.set_subscription( @@ -1677,7 +1644,6 @@ impl SubscriptionNode { start_group, start_object, end_group, - end_object, ); } } @@ -1736,7 +1702,6 @@ impl SubscriptionNode { subscribe_message.start_group(), subscribe_message.start_object(), subscribe_message.end_group(), - subscribe_message.end_object(), ); Ok(()) @@ -1761,6 +1726,16 @@ impl SubscriptionNode { let _ = consumer.activate_subscription(subscribe_id); } } + + fn get_publishing_subscribe_id_by_track_alias(&self, track_alias: u64) -> Option { + if let Some(producer) = &self.producer { + producer + .get_subscribe_id_by_track_alias(track_alias) + .unwrap() + } else { + None + } + } } // Due to the lifetime issue of `spawn_local`, it needs to be kept separate from MOQTClient. @@ -1772,13 +1747,12 @@ struct MOQTCallbacks { announce_responce_callback: Option, subscribe_callback: Option, subscribe_response_callback: Option, - subscribe_namespace_response_callback: Option, + subscribe_announces_response_callback: Option, unsubscribe_callback: Option, datagram_object_callback: Option, - track_stream_header_callback: Option, - track_stream_object_callback: Option, + datagram_object_status_callback: Option, subgroup_stream_header_callback: Option, - subgroup_stream_object_callback: Option, + subgroup_stream_object_callbacks: HashMap, } #[cfg(feature = "web_sys_unstable_apis")] @@ -1790,13 +1764,12 @@ impl MOQTCallbacks { announce_responce_callback: None, subscribe_callback: None, subscribe_response_callback: None, - subscribe_namespace_response_callback: None, + subscribe_announces_response_callback: None, unsubscribe_callback: None, datagram_object_callback: None, - track_stream_header_callback: None, - track_stream_object_callback: None, + datagram_object_status_callback: None, subgroup_stream_header_callback: None, - subgroup_stream_object_callback: None, + subgroup_stream_object_callbacks: HashMap::new(), } } @@ -1840,12 +1813,12 @@ impl MOQTCallbacks { self.subscribe_response_callback = Some(callback); } - pub fn subscribe_namespace_response_callback(&self) -> Option { - self.subscribe_namespace_response_callback.clone() + pub fn subscribe_announces_response_callback(&self) -> Option { + self.subscribe_announces_response_callback.clone() } - pub fn set_subscribe_namespace_response_callback(&mut self, callback: js_sys::Function) { - self.subscribe_namespace_response_callback = Some(callback); + pub fn set_subscribe_announces_response_callback(&mut self, callback: js_sys::Function) { + self.subscribe_announces_response_callback = Some(callback); } pub fn set_unsubscribe_callback(&mut self, callback: js_sys::Function) { @@ -1860,20 +1833,12 @@ impl MOQTCallbacks { self.datagram_object_callback = Some(callback); } - pub fn track_stream_header_callback(&self) -> Option { - self.track_stream_header_callback.clone() - } - - pub fn set_track_stream_header_callback(&mut self, callback: js_sys::Function) { - self.track_stream_header_callback = Some(callback); - } - - pub fn track_stream_object_callback(&self) -> Option { - self.track_stream_object_callback.clone() + pub fn datagram_object_status_callback(&self) -> Option { + self.datagram_object_status_callback.clone() } - pub fn set_track_stream_object_callback(&mut self, callback: js_sys::Function) { - self.track_stream_object_callback = Some(callback); + pub fn set_datagram_object_status_callback(&mut self, callback: js_sys::Function) { + self.datagram_object_status_callback = Some(callback); } pub fn subgroup_stream_header_callback(&self) -> Option { @@ -1884,11 +1849,20 @@ impl MOQTCallbacks { self.subgroup_stream_header_callback = Some(callback); } - pub fn subgroup_stream_object_callback(&self) -> Option { - self.subgroup_stream_object_callback.clone() + pub fn get_subgroup_stream_object_callback( + &self, + track_alias: u64, + ) -> Option<&js_sys::Function> { + let callback = self.subgroup_stream_object_callbacks.get(&track_alias); + callback } - pub fn set_subgroup_stream_object_callback(&mut self, callback: js_sys::Function) { - self.subgroup_stream_object_callback = Some(callback); + pub fn set_subgroup_stream_object_callback( + &mut self, + track_alias: u64, + callback: js_sys::Function, + ) { + self.subgroup_stream_object_callbacks + .insert(track_alias, callback); } } diff --git a/moqt-core/src/lib.rs b/moqt-core/src/lib.rs index bce21bd1..93083259 100644 --- a/moqt-core/src/lib.rs +++ b/moqt-core/src/lib.rs @@ -1,4 +1,3 @@ mod modules; pub use modules::pubsub_relation_manager_repository::PubSubRelationManagerRepository; -pub use modules::send_stream_dispatcher_repository::SendStreamDispatcherRepository; pub use modules::*; diff --git a/moqt-core/src/modules.rs b/moqt-core/src/modules.rs index 362a2597..832e1a5a 100644 --- a/moqt-core/src/modules.rs +++ b/moqt-core/src/modules.rs @@ -4,6 +4,5 @@ pub mod data_stream_type; pub mod messages; pub mod models; pub mod pubsub_relation_manager_repository; -pub mod send_stream_dispatcher_repository; pub mod variable_bytes; pub mod variable_integer; diff --git a/moqt-core/src/modules/constants.rs b/moqt-core/src/modules/constants.rs index 3fb4d9f9..f5096676 100644 --- a/moqt-core/src/modules/constants.rs +++ b/moqt-core/src/modules/constants.rs @@ -1,7 +1,7 @@ use num_enum::IntoPrimitive; -// for draft-ietf-moq-transport-06 -pub const MOQ_TRANSPORT_VERSION: u32 = 0xff000006; +// for draft-ietf-moq-transport-10 +pub const MOQ_TRANSPORT_VERSION: u32 = 0xff00000a; #[derive(Debug, IntoPrimitive, PartialEq, Clone, Copy)] #[repr(u8)] @@ -22,9 +22,3 @@ pub enum UnderlayType { WebTransport, Both, } - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum StreamDirection { - Uni, - Bi, -} diff --git a/moqt-core/src/modules/control_message_type.rs b/moqt-core/src/modules/control_message_type.rs index 8582552b..5fb2ae18 100644 --- a/moqt-core/src/modules/control_message_type.rs +++ b/moqt-core/src/modules/control_message_type.rs @@ -17,11 +17,15 @@ pub enum ControlMessageType { TrackStatusRequest = 0x0d, TrackStatus = 0x0e, GoAway = 0x10, - SubscribeNamespace = 0x11, - SubscribeNamespaceOk = 0x12, - SubscribeNamespaceError = 0x13, - UnSubscribeNamespace = 0x14, + SubscribeAnnounces = 0x11, + SubscribeAnnouncesOk = 0x12, + SubscribeAnnouncesError = 0x13, + UnSubscribeAnnounces = 0x14, MaxSubscribeId = 0x15, + Fetch = 0x16, + FetchCancel = 0x17, + FetchOk = 0x18, + FetchError = 0x19, ClientSetup = 0x40, ServerSetup = 0x41, } diff --git a/moqt-core/src/modules/data_stream_type.rs b/moqt-core/src/modules/data_stream_type.rs index 93e898d3..a11e1de8 100644 --- a/moqt-core/src/modules/data_stream_type.rs +++ b/moqt-core/src/modules/data_stream_type.rs @@ -4,6 +4,7 @@ use num_enum::{IntoPrimitive, TryFromPrimitive}; #[repr(u8)] pub enum DataStreamType { ObjectDatagram = 0x1, - StreamHeaderTrack = 0x2, - StreamHeaderSubgroup = 0x4, + ObjectDatagramStatus = 0x2, + SubgroupHeader = 0x4, + FetchHeader = 0x5, } diff --git a/moqt-core/src/modules/messages/control_messages.rs b/moqt-core/src/modules/messages/control_messages.rs index 207ad804..104dd0ad 100644 --- a/moqt-core/src/modules/messages/control_messages.rs +++ b/moqt-core/src/modules/messages/control_messages.rs @@ -3,14 +3,15 @@ pub mod announce_error; pub mod announce_ok; pub mod client_setup; pub mod go_away; +pub mod group_order; pub mod server_setup; pub mod setup_parameters; pub mod subscribe; +pub mod subscribe_announces; +pub mod subscribe_announces_error; +pub mod subscribe_announces_ok; pub mod subscribe_done; pub mod subscribe_error; -pub mod subscribe_namespace; -pub mod subscribe_namespace_error; -pub mod subscribe_namespace_ok; pub mod subscribe_ok; pub mod unannounce; pub mod unsubscribe; diff --git a/moqt-core/src/modules/messages/control_messages/client_setup.rs b/moqt-core/src/modules/messages/control_messages/client_setup.rs index 8b54bf60..1d63d1e8 100644 --- a/moqt-core/src/modules/messages/control_messages/client_setup.rs +++ b/moqt-core/src/modules/messages/control_messages/client_setup.rs @@ -82,7 +82,7 @@ mod test { messages::{ control_messages::{ client_setup::ClientSetup, - setup_parameters::{Role, RoleCase, SetupParameter}, + setup_parameters::{MaxSubscribeID, SetupParameter}, }, moqt_payload::MOQTPayload, }, @@ -92,8 +92,7 @@ mod test { #[test] fn packetize() { let supported_versions = vec![MOQ_TRANSPORT_VERSION]; - let role_parameter = Role::new(RoleCase::Subscriber); - let setup_parameters = vec![SetupParameter::Role(role_parameter.clone())]; + let setup_parameters = vec![SetupParameter::MaxSubscribeID(MaxSubscribeID::new(2000))]; let client_setup = ClientSetup::new(supported_versions, setup_parameters.clone()); let mut buf = BytesMut::new(); client_setup.packetize(&mut buf); @@ -101,11 +100,12 @@ mod test { let expected_bytes_array = [ 1, // Number of Supported Versions (i) 192, // Supported Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 2, // SETUP Parameters (..): Role(Subscriber) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff000008) in 62bit + 1, // Number of Parameters (i) + 2, // Parameter Type (i): Type(MaxSubscribeID) + 2, // Parameter Length (i) + 71, // Parameter Value (..): Length(01 of 2MSB) + 208, // Parameter Value (..): Value(2000) in 62bit ]; assert_eq!(buf.as_ref(), expected_bytes_array); @@ -116,19 +116,19 @@ mod test { let bytes_array = [ 1, // Number of Supported Versions (i) 192, // Supported Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 2, // SETUP Parameters (..): Role(Subscriber) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff000008) in 62bit + 1, // Number of Parameters (i) + 2, // Parameter Type (i): Type(MaxSubscribeID) + 2, // Parameter Length (i) + 71, // Parameter Value (..): Length(01 of 2MSB) + 208, // Parameter Value (..): Value(2000) in 62bit ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); let depacketized_client_setup = ClientSetup::depacketize(&mut buf).unwrap(); let supported_versions = vec![MOQ_TRANSPORT_VERSION]; - let role_parameter = Role::new(RoleCase::Subscriber); - let setup_parameters = vec![SetupParameter::Role(role_parameter.clone())]; + let setup_parameters = vec![SetupParameter::MaxSubscribeID(MaxSubscribeID::new(2000))]; let expected_client_setup = ClientSetup::new(supported_versions, setup_parameters.clone()); diff --git a/moqt-core/src/modules/messages/control_messages/group_order.rs b/moqt-core/src/modules/messages/control_messages/group_order.rs new file mode 100644 index 00000000..546cfce0 --- /dev/null +++ b/moqt-core/src/modules/messages/control_messages/group_order.rs @@ -0,0 +1,9 @@ +use num_enum::{IntoPrimitive, TryFromPrimitive}; +use serde::Serialize; +#[derive(Debug, Serialize, Clone, PartialEq, Eq, TryFromPrimitive, IntoPrimitive, Copy)] +#[repr(u8)] +pub enum GroupOrder { + Original = 0x0, // Use the original publisher's Group Order + Ascending = 0x1, + Descending = 0x2, +} diff --git a/moqt-core/src/modules/messages/control_messages/server_setup.rs b/moqt-core/src/modules/messages/control_messages/server_setup.rs index 8984f81e..a6f323d1 100644 --- a/moqt-core/src/modules/messages/control_messages/server_setup.rs +++ b/moqt-core/src/modules/messages/control_messages/server_setup.rs @@ -69,7 +69,7 @@ mod tests { messages::{ control_messages::{ server_setup::ServerSetup, - setup_parameters::{Role, RoleCase, SetupParameter}, + setup_parameters::{MaxSubscribeID, SetupParameter}, }, moqt_payload::MOQTPayload, }, @@ -79,19 +79,19 @@ mod tests { #[test] fn packetize() { let selected_version = MOQ_TRANSPORT_VERSION; - let role_parameter = Role::new(RoleCase::PubSub); - let setup_parameters = vec![SetupParameter::Role(role_parameter.clone())]; + let setup_parameters = vec![SetupParameter::MaxSubscribeID(MaxSubscribeID::new(2000))]; let server_setup = ServerSetup::new(selected_version, setup_parameters.clone()); let mut buf = BytesMut::new(); server_setup.packetize(&mut buf); let expected_bytes_array = [ 192, // Selected Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 3, // SETUP Parameters (..): Value(PubSub) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff000a) in 62bit + 1, // Number of Parameters (i) + 2, // Parameter Type (i): Type(MaxSubscribeID) + 2, // Parameter Length (i) + 71, // Parameter Value (..): Length(01 of 2MSB) + 208, // Parameter Value (..): Value(2000) in 62bit ]; assert_eq!(buf.as_ref(), expected_bytes_array); @@ -101,19 +101,19 @@ mod tests { fn depacketize() { let bytes_array = [ 192, // Selected Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 3, // SETUP Parameters (..): Value(PubSub) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff00000a) in 62bit + 1, // Number of Parameters (i) + 2, // Parameter Type (i): Type(MaxSubscribeID) + 2, // Parameter Length (i) + 71, // Parameter Value (..): Length(01 of 2MSB) + 208, // Parameter Value (..): Value(2000) in 62bit ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); let depacketized_server_setup = ServerSetup::depacketize(&mut buf).unwrap(); let selected_version = MOQ_TRANSPORT_VERSION; - let role_parameter = Role::new(RoleCase::PubSub); - let setup_parameters = vec![SetupParameter::Role(role_parameter.clone())]; + let setup_parameters = vec![SetupParameter::MaxSubscribeID(MaxSubscribeID::new(2000))]; let expected_server_setup = ServerSetup::new(selected_version, setup_parameters.clone()); diff --git a/moqt-core/src/modules/messages/control_messages/setup_parameters.rs b/moqt-core/src/modules/messages/control_messages/setup_parameters.rs index 2eec4742..a063f413 100644 --- a/moqt-core/src/modules/messages/control_messages/setup_parameters.rs +++ b/moqt-core/src/modules/messages/control_messages/setup_parameters.rs @@ -2,7 +2,7 @@ use crate::{ messages::moqt_payload::MOQTPayload, variable_integer::{read_variable_integer_from_buffer, write_variable_integer}, }; -use anyhow::{bail, ensure, Context, Result}; +use anyhow::{bail, Context, Result}; use bytes::BytesMut; use num_enum::{IntoPrimitive, TryFromPrimitive}; use serde::Serialize; @@ -10,7 +10,6 @@ use std::any::Any; #[derive(Debug, Serialize, Clone, PartialEq)] pub enum SetupParameter { - Role(Role), Path(Path), MaxSubscribeID(MaxSubscribeID), Unknown(u8), @@ -27,27 +26,6 @@ impl MOQTPayload for SetupParameter { } match key? { - SetupParameterType::Role => { - let length = u8::try_from(read_variable_integer_from_buffer(buf)?) - .context("role value length")?; - - // TODO: return TerminationError - ensure!( - length == 1, - "Invalid value length in ROLE parameter {:#04x}", - length - ); - - let value = RoleCase::try_from(u8::try_from( - read_variable_integer_from_buffer(buf).context("role value")?, - )?); - if let Err(err) = value { - bail!("Invalid value in ROLE parameter {:?}", err); - } - - Ok(SetupParameter::Role(Role::new(value?))) - } - // Not implemented as only WebTransport is supported now. SetupParameterType::Path => { // let value = String::from_utf8(read_variable_bytes_from_buffer(buf)?)?; @@ -55,7 +33,6 @@ impl MOQTPayload for SetupParameter { unimplemented!("Not implemented as only WebTransport is supported.") } - SetupParameterType::MaxSubscribeID => { let length = read_variable_integer_from_buffer(buf)?; let value = read_variable_integer_from_buffer(buf).context("max subscribe id")?; @@ -72,25 +49,16 @@ impl MOQTPayload for SetupParameter { fn packetize(&self, buf: &mut BytesMut) { match self { - SetupParameter::Role(param) => { - buf.extend(write_variable_integer(u8::from(param.key) as u64)); - buf.extend(write_variable_integer(param.length)); - // The value is of type varint. - buf.extend(write_variable_integer(u8::from(param.value) as u64)); - } - // Not implemented as only WebTransport is supported now. SetupParameter::Path(_param) => { unimplemented!("Not implemented as only WebTransport is supported.") } - SetupParameter::MaxSubscribeID(param) => { buf.extend(write_variable_integer(u8::from(param.key) as u64)); buf.extend(write_variable_integer(param.length)); // The value is of type varint (from MAX_SUBSCRIBE_ID message format). buf.extend(write_variable_integer(param.value)); } - SetupParameter::Unknown(_) => unimplemented!("Unknown SETUP parameter"), } } @@ -103,36 +71,10 @@ impl MOQTPayload for SetupParameter { #[derive(Debug, Clone, Copy, IntoPrimitive, TryFromPrimitive, Serialize, PartialEq)] #[repr(u8)] pub enum SetupParameterType { - Role = 0x00, Path = 0x01, MaxSubscribeID = 0x02, } -#[derive(Debug, Clone, Copy, IntoPrimitive, TryFromPrimitive, Serialize, PartialEq)] -#[repr(u8)] -pub enum RoleCase { - Publisher = 0x01, - Subscriber = 0x02, - PubSub = 0x03, -} - -#[derive(Debug, Serialize, Clone, PartialEq)] -pub struct Role { - pub key: SetupParameterType, - pub length: u64, - pub value: RoleCase, -} - -impl Role { - pub fn new(role: RoleCase) -> Self { - Role { - key: SetupParameterType::Role, - length: 0x01, - value: role, - } - } -} - #[derive(Debug, Serialize, Clone, PartialEq)] pub struct Path { pub key: SetupParameterType, @@ -174,27 +116,10 @@ mod tests { use bytes::BytesMut; use crate::messages::{ - control_messages::setup_parameters::{MaxSubscribeID, Role, RoleCase, SetupParameter}, + control_messages::setup_parameters::{MaxSubscribeID, SetupParameter}, moqt_payload::MOQTPayload, }; - #[test] - fn packetize_role() { - let role_parameter = Role::new(RoleCase::Publisher); - let setup_parameter = SetupParameter::Role(role_parameter); - - let mut buf = BytesMut::new(); - setup_parameter.packetize(&mut buf); - - let expected_bytes_array = [ - 0, // Parameter Type (i): Role - 1, // Parameter Length (i) - 1, // Parameter Value (..): Role(Publisher) - ]; - - assert_eq!(buf.as_ref(), expected_bytes_array); - } - #[test] fn packetize_max_subscribe_id() { let max_subscribe_id = MaxSubscribeID::new(2000); @@ -213,22 +138,6 @@ mod tests { assert_eq!(buf.as_ref(), expected_bytes_array); } - #[test] - fn depacketize_role() { - let bytes_array = [ - 0, // Parameter Type (i): Role - 1, // Parameter Length (i) - 2, // Parameter Value (..): Role(Subscriber) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let depacketized_setup_parameter = SetupParameter::depacketize(&mut buf).unwrap(); - - let role_parameter = Role::new(RoleCase::Subscriber); - let expected_setup_parameter = SetupParameter::Role(role_parameter); - assert_eq!(depacketized_setup_parameter, expected_setup_parameter); - } - #[test] fn depacketize_max_subscribe_id() { let bytes_array = [ @@ -284,34 +193,6 @@ mod tests { setup_parameter.packetize(&mut buf); } - #[test] - fn depacketize_role_invalid_length() { - let bytes_array = [ - 0, // Parameter Type (i): Type(Role) - 99, // Parameter Type (i): Length(Wrong) - 1, // Parameter Type (i): Role(Publisher) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let depacketized_setup_parameter = SetupParameter::depacketize(&mut buf); - - assert!(depacketized_setup_parameter.is_err()); - } - - #[test] - fn depacketize_role_invalid_value() { - let bytes_array = [ - 0, // Parameter Type (i): Type(Role) - 1, // Parameter Type (i): Length - 99, // Parameter Type (i): Role(Wrong) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let depacketized_setup_parameter = SetupParameter::depacketize(&mut buf); - - assert!(depacketized_setup_parameter.is_err()); - } - #[test] #[should_panic] fn depacketize_path() { diff --git a/moqt-core/src/modules/messages/control_messages/subscribe.rs b/moqt-core/src/modules/messages/control_messages/subscribe.rs index 67fd385e..edd85958 100644 --- a/moqt-core/src/modules/messages/control_messages/subscribe.rs +++ b/moqt-core/src/modules/messages/control_messages/subscribe.rs @@ -1,10 +1,12 @@ use crate::{ messages::{ - control_messages::version_specific_parameters::VersionSpecificParameter, + control_messages::{ + group_order::GroupOrder, version_specific_parameters::VersionSpecificParameter, + }, moqt_payload::MOQTPayload, }, variable_bytes::{ - read_fixed_length_bytes_from_buffer, read_variable_bytes_from_buffer, write_variable_bytes, + read_bytes_from_buffer, read_variable_bytes_from_buffer, write_variable_bytes, }, variable_integer::{read_variable_integer_from_buffer, write_variable_integer}, }; @@ -15,14 +17,7 @@ use serde::Serialize; use std::any::Any; use tracing; -#[derive(Debug, Serialize, Clone, PartialEq, Eq, TryFromPrimitive, IntoPrimitive, Copy)] -#[repr(u8)] -pub enum GroupOrder { - Original = 0x0, // Use the original publisher's Group Order - Ascending = 0x1, - Descending = 0x2, -} - +// TODO: Remove LatestGroup since it is not exist in the draft-10 #[derive(Debug, Serialize, Clone, PartialEq, Eq, TryFromPrimitive, IntoPrimitive, Copy)] #[repr(u8)] pub enum FilterType { @@ -44,7 +39,6 @@ pub struct Subscribe { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, number_of_parameters: u64, subscribe_parameters: Vec, } @@ -62,12 +56,11 @@ impl Subscribe { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, subscribe_parameters: Vec, ) -> anyhow::Result { - // If FilterType is LatestGroup or LatestObject, start_group/start_object/end_group/end_object must be None - // If FilterType is AbsoluteStart, start_group/start_object must be needed and end_group/end_object must be None - // If FilterType is AbsoluteRange, start_group/start_object/end_group/end_object must be needed + // If FilterType is LatestGroup or LatestObject, start_group/start_object/end_group must be None + // If FilterType is AbsoluteStart, start_group/start_object must be needed and end_group must be None + // If FilterType is AbsoluteRange, start_group/start_object/end_group must be needed match filter_type { FilterType::LatestGroup | FilterType::LatestObject => { if start_group.is_some() { @@ -76,8 +69,6 @@ impl Subscribe { bail!("start_object must be None for LatestGroup or LatestObject"); } else if end_group.is_some() { bail!("end_group must be None for LatestGroup or LatestObject"); - } else if end_object.is_some() { - bail!("end_object must be None for LatestGroup or LatestObject"); } } FilterType::AbsoluteStart => { @@ -87,8 +78,6 @@ impl Subscribe { bail!("start_object must be Some for AbsoluteStart"); } else if end_group.is_some() { bail!("end_group must be None for AbsoluteStart"); - } else if end_object.is_some() { - bail!("end_object must be None for AbsoluteStart"); } } FilterType::AbsoluteRange => { @@ -98,8 +87,6 @@ impl Subscribe { bail!("start_object must be Some for AbsoluteRange"); } else if end_group.is_none() { bail!("end_group must be Some for AbsoluteRange"); - } else if end_object.is_none() { - bail!("end_object must be Some for AbsoluteRange"); } } } @@ -116,7 +103,6 @@ impl Subscribe { start_group, start_object, end_group, - end_object, number_of_parameters, subscribe_parameters, }) @@ -162,10 +148,6 @@ impl Subscribe { self.end_group } - pub fn end_object(&self) -> Option { - self.end_object - } - pub fn subscribe_parameters(&self) -> &Vec { &self.subscribe_parameters } @@ -188,9 +170,8 @@ impl MOQTPayload for Subscribe { } let track_name = String::from_utf8(read_variable_bytes_from_buffer(buf)?).context("track name")?; - let subscriber_priority = - read_fixed_length_bytes_from_buffer(buf, 1).context("subscriber priority")?[0]; - let group_order_u8 = read_fixed_length_bytes_from_buffer(buf, 1)?[0]; + let subscriber_priority = read_bytes_from_buffer(buf, 1).context("subscriber priority")?[0]; + let group_order_u8 = read_bytes_from_buffer(buf, 1)?[0]; // Values larger than 0x2 are a Protocol Violation. let group_order = match GroupOrder::try_from(group_order_u8).context("group order") { @@ -219,12 +200,11 @@ impl MOQTPayload for Subscribe { _ => (None, None), }; - let (end_group, end_object) = match filter_type { - FilterType::AbsoluteRange => ( - Some(read_variable_integer_from_buffer(buf).context("end group")?), - Some(read_variable_integer_from_buffer(buf).context("end object")?), - ), - _ => (None, None), + let end_group = match filter_type { + FilterType::AbsoluteRange => { + Some(read_variable_integer_from_buffer(buf).context("end group")?) + } + _ => None, }; let number_of_parameters = read_variable_integer_from_buffer(buf).context("number of parameters")?; @@ -251,7 +231,6 @@ impl MOQTPayload for Subscribe { start_group, start_object, end_group, - end_object, number_of_parameters, subscribe_parameters, }) @@ -280,7 +259,6 @@ impl MOQTPayload for Subscribe { buf.extend(write_variable_integer(self.start_group.unwrap())); buf.extend(write_variable_integer(self.start_object.unwrap())); buf.extend(write_variable_integer(self.end_group.unwrap())); - buf.extend(write_variable_integer(self.end_object.unwrap())); } _ => {} } @@ -324,7 +302,6 @@ mod tests { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -341,7 +318,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -383,7 +359,6 @@ mod tests { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -400,7 +375,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -444,7 +418,6 @@ mod tests { let start_group = Some(0); let start_object = Some(0); let end_group = Some(10); - let end_object = Some(100); let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -461,7 +434,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -486,7 +458,6 @@ mod tests { 0, // Start Group (i) 0, // Start Object (i) 10, // End Group (i) - 64, 100, // End Object (i) 1, // Track Request Parameters (..): Number of Parameters 2, // Parameter Type (i): AuthorizationInfo 4, // Parameter Length @@ -530,7 +501,6 @@ mod tests { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -546,7 +516,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -591,7 +560,6 @@ mod tests { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -607,7 +575,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -634,7 +601,6 @@ mod tests { 0, // Start Group (i) 0, // Start Object (i) 10, // End Group (i) - 64, 100, // End Object (i) 1, // Track Request Parameters (..): Number of Parameters 2, // Parameter Type (i): AuthorizationInfo 4, // Parameter Length @@ -654,7 +620,6 @@ mod tests { let start_group = Some(0); let start_object = Some(0); let end_group = Some(10); - let end_object = Some(100); let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -670,7 +635,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -701,7 +665,6 @@ mod tests { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -718,7 +681,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ); @@ -737,7 +699,6 @@ mod tests { let start_group = None; let start_object = None; let end_group = Some(1); - let end_object = Some(1); let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -754,7 +715,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ); @@ -773,7 +733,6 @@ mod tests { let start_group = Some(0); let start_object = Some(0); let end_group = Some(1); - let end_object = Some(1); let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); @@ -790,7 +749,6 @@ mod tests { start_group, start_object, end_group, - end_object, subscribe_parameters, ); diff --git a/moqt-core/src/modules/messages/control_messages/subscribe_namespace.rs b/moqt-core/src/modules/messages/control_messages/subscribe_announces.rs similarity index 90% rename from moqt-core/src/modules/messages/control_messages/subscribe_namespace.rs rename to moqt-core/src/modules/messages/control_messages/subscribe_announces.rs index 2ebbc586..1045582c 100644 --- a/moqt-core/src/modules/messages/control_messages/subscribe_namespace.rs +++ b/moqt-core/src/modules/messages/control_messages/subscribe_announces.rs @@ -12,19 +12,19 @@ use serde::Serialize; use std::any::Any; #[derive(Debug, Serialize, Clone, PartialEq)] -pub struct SubscribeNamespace { +pub struct SubscribeAnnounces { track_namespace_prefix: Vec, number_of_parameters: u64, parameters: Vec, } -impl SubscribeNamespace { +impl SubscribeAnnounces { pub fn new( track_namespace_prefix: Vec, parameters: Vec, ) -> Self { let number_of_parameters = parameters.len() as u64; - SubscribeNamespace { + SubscribeAnnounces { track_namespace_prefix, number_of_parameters, parameters, @@ -40,7 +40,7 @@ impl SubscribeNamespace { } } -impl MOQTPayload for SubscribeNamespace { +impl MOQTPayload for SubscribeAnnounces { fn depacketize(buf: &mut BytesMut) -> Result { let track_namespace_prefix_tuple_length = u8::try_from(read_variable_integer_from_buffer(buf)?) @@ -65,7 +65,7 @@ impl MOQTPayload for SubscribeNamespace { } } - Ok(SubscribeNamespace { + Ok(SubscribeAnnounces { track_namespace_prefix: track_namespace_prefix_tuple, number_of_parameters, parameters, @@ -87,7 +87,7 @@ impl MOQTPayload for SubscribeNamespace { parameter.packetize(buf); } } - /// Method to enable downcasting from MOQTPayload to SubscribeNamespace + /// Method to enable downcasting from MOQTPayload to SubscribeAnnounces fn as_any(&self) -> &dyn Any { self } @@ -98,7 +98,7 @@ mod tests { mod success { use crate::messages::{ control_messages::{ - subscribe_namespace::SubscribeNamespace, + subscribe_announces::SubscribeAnnounces, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, moqt_payload::MOQTPayload, @@ -112,10 +112,10 @@ mod tests { AuthorizationInfo::new("test".to_string()), ); let parameters = vec![version_specific_parameter]; - let subscribe_namespace = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace.packetize(&mut buf); + subscribe_announces.packetize(&mut buf); let expected_bytes_array = [ 2, // Track Namespace Prefix(tuple): Number of elements @@ -146,17 +146,17 @@ mod tests { ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); - let subscribe_namespace = SubscribeNamespace::depacketize(&mut buf).unwrap(); + let subscribe_announces = SubscribeAnnounces::depacketize(&mut buf).unwrap(); let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo( AuthorizationInfo::new("test".to_string()), ); let parameters = vec![version_specific_parameter]; - let expected_subscribe_namespace = - SubscribeNamespace::new(track_namespace_prefix, parameters); + let expected_subscribe_announces = + SubscribeAnnounces::new(track_namespace_prefix, parameters); - assert_eq!(subscribe_namespace, expected_subscribe_namespace); + assert_eq!(subscribe_announces, expected_subscribe_announces); } } } diff --git a/moqt-core/src/modules/messages/control_messages/subscribe_namespace_error.rs b/moqt-core/src/modules/messages/control_messages/subscribe_announces_error.rs similarity index 79% rename from moqt-core/src/modules/messages/control_messages/subscribe_namespace_error.rs rename to moqt-core/src/modules/messages/control_messages/subscribe_announces_error.rs index f5662d00..4d241521 100644 --- a/moqt-core/src/modules/messages/control_messages/subscribe_namespace_error.rs +++ b/moqt-core/src/modules/messages/control_messages/subscribe_announces_error.rs @@ -9,19 +9,19 @@ use serde::Serialize; use std::any::Any; #[derive(Debug, Serialize, Clone, PartialEq)] -pub struct SubscribeNamespaceError { +pub struct SubscribeAnnouncesError { track_namespace_prefix: Vec, error_code: u64, reason_phrase: String, } -impl SubscribeNamespaceError { +impl SubscribeAnnouncesError { pub fn new( track_namespace_prefix: Vec, error_code: u64, reason_phrase: String, ) -> Self { - SubscribeNamespaceError { + SubscribeAnnouncesError { track_namespace_prefix, error_code, reason_phrase, @@ -41,7 +41,7 @@ impl SubscribeNamespaceError { } } -impl MOQTPayload for SubscribeNamespaceError { +impl MOQTPayload for SubscribeAnnouncesError { fn depacketize(buf: &mut BytesMut) -> Result { let track_namespace_prefix_tuple_length = u8::try_from(read_variable_integer_from_buffer(buf)?) @@ -56,7 +56,7 @@ impl MOQTPayload for SubscribeNamespaceError { let reason_phrase = String::from_utf8(read_variable_bytes_from_buffer(buf)?).context("reason phrase")?; - Ok(SubscribeNamespaceError { + Ok(SubscribeAnnouncesError { track_namespace_prefix: track_namespace_prefix_tuple, error_code, reason_phrase, @@ -78,7 +78,7 @@ impl MOQTPayload for SubscribeNamespaceError { &self.reason_phrase.as_bytes().to_vec(), )); } - /// Method to enable downcasting from MOQTPayload to SubscribeNamespaceError + /// Method to enable downcasting from MOQTPayload to SubscribeAnnouncesError fn as_any(&self) -> &dyn Any { self } @@ -88,7 +88,7 @@ impl MOQTPayload for SubscribeNamespaceError { mod tests { mod success { use crate::messages::{ - control_messages::subscribe_namespace_error::SubscribeNamespaceError, + control_messages::subscribe_announces_error::SubscribeAnnouncesError, moqt_payload::MOQTPayload, }; use bytes::BytesMut; @@ -97,14 +97,14 @@ mod tests { fn packetize() { let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); let error_code: u64 = 1; - let reason_phrase = "subscribe namespace overlap".to_string(); - let subscribe_namespace_error = SubscribeNamespaceError::new( + let reason_phrase = "subscribe announces overlap".to_string(); + let subscribe_announces_error = SubscribeAnnouncesError::new( track_namespace_prefix.clone(), error_code, reason_phrase.clone(), ); let mut buf = BytesMut::new(); - subscribe_namespace_error.packetize(&mut buf); + subscribe_announces_error.packetize(&mut buf); let expected_bytes_array = [ 2, // Track Namespace Prefix(tuple): Number of elements @@ -114,9 +114,9 @@ mod tests { 116, 101, 115, 116, // Track Namespace Prefix(b): Value("test") 1, // Error Code (i) 27, // Reason Phrase (b): length - 115, 117, 98, 115, 99, 114, 105, 98, 101, 32, 110, 97, 109, 101, 115, 112, 97, 99, - 101, 32, 111, 118, 101, 114, 108, 97, - 112, // Reason Phrase (b): Value("subscribe namespace overlap") + 115, 117, 98, 115, 99, 114, 105, 98, 101, 32, 97, 110, 110, 111, 117, 110, 99, 101, + 115, 32, 111, 118, 101, 114, 108, 97, + 112, // Reason Phrase (b): Value("subscribe announces overlap") ]; assert_eq!(buf.as_ref(), expected_bytes_array.as_slice()); } @@ -131,26 +131,26 @@ mod tests { 116, 101, 115, 116, // Track Namespace Prefix(b): Value("test") 1, // Error Code (i) 27, // Reason Phrase (b): length - 115, 117, 98, 115, 99, 114, 105, 98, 101, 32, 110, 97, 109, 101, 115, 112, 97, 99, - 101, 32, 111, 118, 101, 114, 108, 97, - 112, // Reason Phrase (b): Value("subscribe namespace overlap") + 115, 117, 98, 115, 99, 114, 105, 98, 101, 32, 97, 110, 110, 111, 117, 110, 99, 101, + 115, 32, 111, 118, 101, 114, 108, 97, + 112, // Reason Phrase (b): Value("subscribe announces overlap") ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); - let subscribe_namespace_error = SubscribeNamespaceError::depacketize(&mut buf).unwrap(); + let subscribe_announces_error = SubscribeAnnouncesError::depacketize(&mut buf).unwrap(); let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); let error_code: u64 = 1; - let reason_phrase = "subscribe namespace overlap".to_string(); - let expected_subscribe_namespace_error = SubscribeNamespaceError::new( + let reason_phrase = "subscribe announces overlap".to_string(); + let expected_subscribe_announces_error = SubscribeAnnouncesError::new( track_namespace_prefix.clone(), error_code, reason_phrase.clone(), ); assert_eq!( - subscribe_namespace_error, - expected_subscribe_namespace_error + subscribe_announces_error, + expected_subscribe_announces_error ); } } diff --git a/moqt-core/src/modules/messages/control_messages/subscribe_namespace_ok.rs b/moqt-core/src/modules/messages/control_messages/subscribe_announces_ok.rs similarity index 84% rename from moqt-core/src/modules/messages/control_messages/subscribe_namespace_ok.rs rename to moqt-core/src/modules/messages/control_messages/subscribe_announces_ok.rs index e011a202..788546bf 100644 --- a/moqt-core/src/modules/messages/control_messages/subscribe_namespace_ok.rs +++ b/moqt-core/src/modules/messages/control_messages/subscribe_announces_ok.rs @@ -9,13 +9,13 @@ use serde::Serialize; use std::any::Any; #[derive(Debug, Serialize, Clone, PartialEq)] -pub struct SubscribeNamespaceOk { +pub struct SubscribeAnnouncesOk { track_namespace_prefix: Vec, } -impl SubscribeNamespaceOk { +impl SubscribeAnnouncesOk { pub fn new(track_namespace_prefix: Vec) -> Self { - SubscribeNamespaceOk { + SubscribeAnnouncesOk { track_namespace_prefix, } } @@ -25,7 +25,7 @@ impl SubscribeNamespaceOk { } } -impl MOQTPayload for SubscribeNamespaceOk { +impl MOQTPayload for SubscribeAnnouncesOk { fn depacketize(buf: &mut BytesMut) -> Result { let track_namespace_prefix_tuple_length = u8::try_from(read_variable_integer_from_buffer(buf)?) @@ -37,7 +37,7 @@ impl MOQTPayload for SubscribeNamespaceOk { track_namespace_prefix_tuple.push(track_namespace_prefix); } - Ok(SubscribeNamespaceOk { + Ok(SubscribeAnnouncesOk { track_namespace_prefix: track_namespace_prefix_tuple, }) } @@ -53,7 +53,7 @@ impl MOQTPayload for SubscribeNamespaceOk { )); } } - /// Method to enable downcasting from MOQTPayload to SubscribeNamespaceOk + /// Method to enable downcasting from MOQTPayload to SubscribeAnnouncesOk fn as_any(&self) -> &dyn Any { self } @@ -63,7 +63,7 @@ impl MOQTPayload for SubscribeNamespaceOk { mod tests { mod success { use crate::messages::{ - control_messages::subscribe_namespace_ok::SubscribeNamespaceOk, + control_messages::subscribe_announces_ok::SubscribeAnnouncesOk, moqt_payload::MOQTPayload, }; use bytes::BytesMut; @@ -71,9 +71,9 @@ mod tests { #[test] fn packetize() { let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); - let subscribe_namespace_ok = SubscribeNamespaceOk::new(track_namespace_prefix.clone()); + let subscribe_announces_ok = SubscribeAnnouncesOk::new(track_namespace_prefix.clone()); let mut buf = BytesMut::new(); - subscribe_namespace_ok.packetize(&mut buf); + subscribe_announces_ok.packetize(&mut buf); let expected_bytes_array = [ 2, // Track Namespace Prefix(tuple): Number of elements @@ -96,12 +96,12 @@ mod tests { ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); - let subscribe_namespace_ok = SubscribeNamespaceOk::depacketize(&mut buf).unwrap(); + let subscribe_announces_ok = SubscribeAnnouncesOk::depacketize(&mut buf).unwrap(); let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); - let expected_subscribe_namespace_ok = SubscribeNamespaceOk::new(track_namespace_prefix); + let expected_subscribe_announces_ok = SubscribeAnnouncesOk::new(track_namespace_prefix); - assert_eq!(subscribe_namespace_ok, expected_subscribe_namespace_ok); + assert_eq!(subscribe_announces_ok, expected_subscribe_announces_ok); } } } diff --git a/moqt-core/src/modules/messages/control_messages/subscribe_done.rs b/moqt-core/src/modules/messages/control_messages/subscribe_done.rs index aeb0340b..d4fd949d 100644 --- a/moqt-core/src/modules/messages/control_messages/subscribe_done.rs +++ b/moqt-core/src/modules/messages/control_messages/subscribe_done.rs @@ -1,7 +1,7 @@ use crate::{ messages::moqt_payload::MOQTPayload, variable_bytes::{ - read_fixed_length_bytes_from_buffer, read_variable_bytes_from_buffer, write_variable_bytes, + read_bytes_from_buffer, read_variable_bytes_from_buffer, write_variable_bytes, }, variable_integer::{read_variable_integer_from_buffer, write_variable_integer}, }; @@ -60,15 +60,14 @@ impl MOQTPayload for SubscribeDone { let status_code = StatusCode::try_from(status_code_u64).context("status code")?; let reason_phrase = String::from_utf8(read_variable_bytes_from_buffer(buf)?).context("reason phrase")?; - let content_exists = - match read_fixed_length_bytes_from_buffer(buf, 1).context("content_exists")?[0] { - 0 => false, - 1 => true, - _ => { - // TODO: return Termination Error Code - bail!("Invalid content_exists value: Protocol Violation"); - } - }; + let content_exists = match read_bytes_from_buffer(buf, 1).context("content_exists")?[0] { + 0 => false, + 1 => true, + _ => { + // TODO: return Termination Error Code + bail!("Invalid content_exists value: Protocol Violation"); + } + }; let (final_group_id, final_object_id) = match content_exists { true => { diff --git a/moqt-core/src/modules/messages/control_messages/subscribe_ok.rs b/moqt-core/src/modules/messages/control_messages/subscribe_ok.rs index cca3df67..b23f633d 100644 --- a/moqt-core/src/modules/messages/control_messages/subscribe_ok.rs +++ b/moqt-core/src/modules/messages/control_messages/subscribe_ok.rs @@ -1,11 +1,11 @@ use crate::{ messages::{ control_messages::{ - subscribe::GroupOrder, version_specific_parameters::VersionSpecificParameter, + group_order::GroupOrder, version_specific_parameters::VersionSpecificParameter, }, moqt_payload::MOQTPayload, }, - variable_bytes::read_fixed_length_bytes_from_buffer, + variable_bytes::read_bytes_from_buffer, variable_integer::{read_variable_integer_from_buffer, write_variable_integer}, }; use anyhow::bail; @@ -85,7 +85,7 @@ impl MOQTPayload for SubscribeOk { { let subscribe_id = read_variable_integer_from_buffer(buf).context("subscribe_id")?; let expires = read_variable_integer_from_buffer(buf).context("expires")?; - let group_order_u8 = read_fixed_length_bytes_from_buffer(buf, 1)?[0]; + let group_order_u8 = read_bytes_from_buffer(buf, 1)?[0]; // Values larger than 0x2 are a Protocol Violation. let group_order = match GroupOrder::try_from(group_order_u8).context("group order") { @@ -96,15 +96,14 @@ impl MOQTPayload for SubscribeOk { } }; - let content_exists = - match read_fixed_length_bytes_from_buffer(buf, 1).context("content_exists")?[0] { - 0 => false, - 1 => true, - _ => { - // TODO: return Termination Error Code - bail!("Invalid content_exists value: Protocol Violation"); - } - }; + let content_exists = match read_bytes_from_buffer(buf, 1).context("content_exists")?[0] { + 0 => false, + 1 => true, + _ => { + // TODO: return Termination Error Code + bail!("Invalid content_exists value: Protocol Violation"); + } + }; let (largest_group_id, largest_object_id) = if content_exists { let largest_group_id = diff --git a/moqt-core/src/modules/messages/control_messages/version_specific_parameters.rs b/moqt-core/src/modules/messages/control_messages/version_specific_parameters.rs index 73b4d0cf..c67ad3d7 100644 --- a/moqt-core/src/modules/messages/control_messages/version_specific_parameters.rs +++ b/moqt-core/src/modules/messages/control_messages/version_specific_parameters.rs @@ -1,10 +1,8 @@ use crate::{ messages::moqt_payload::MOQTPayload, - variable_bytes::{ - convert_bytes_to_integer, read_fixed_length_bytes_from_buffer, write_fixed_length_bytes, - }, + variable_bytes::{bytes_to_integer, read_bytes_from_buffer, write_bytes}, variable_integer::{ - get_length_from_variable_integer_first_byte, read_variable_integer_from_buffer, + get_2msb_length_from_first_byte, get_2msb_value, read_variable_integer_from_buffer, write_variable_integer, }, }; @@ -32,7 +30,7 @@ impl MOQTPayload for VersionSpecificParameter { read_variable_integer_from_buffer(buf)?, )?); let parameter_length = read_variable_integer_from_buffer(buf)?; - let parameter_value = read_fixed_length_bytes_from_buffer(buf, parameter_length as usize)?; + let parameter_value = read_bytes_from_buffer(buf, parameter_length as usize)?; if let Err(err) = parameter_type { // If it appears in some other type of message, it MUST be ignored. @@ -53,7 +51,7 @@ impl MOQTPayload for VersionSpecificParameter { } VersionSpecificParameterType::DeliveryTimeout => { // The value is of type varint. - let parameter_value: u64 = convert_bytes_to_integer(parameter_value)?; + let parameter_value: u64 = bytes_to_integer(parameter_value)?; Ok(VersionSpecificParameter::DeliveryTimeout( DeliveryTimeout::new(parameter_value), @@ -61,7 +59,7 @@ impl MOQTPayload for VersionSpecificParameter { } VersionSpecificParameterType::MaxCacheDuration => { // The value is of type varint. - let parameter_value: u64 = convert_bytes_to_integer(parameter_value)?; + let parameter_value: u64 = bytes_to_integer(parameter_value)?; Ok(VersionSpecificParameter::MaxCacheDuration( MaxCacheDuration::new(parameter_value), @@ -76,7 +74,7 @@ impl MOQTPayload for VersionSpecificParameter { buf.extend(write_variable_integer(u64::from(param.parameter_type))); buf.extend(write_variable_integer(param.length as u64)); // The value is an ASCII string. - buf.extend(write_fixed_length_bytes(¶m.value.as_bytes().to_vec())); + buf.extend(write_bytes(¶m.value.as_bytes().to_vec())); } VersionSpecificParameter::DeliveryTimeout(param) => { buf.extend(write_variable_integer(u64::from(param.parameter_type))); @@ -142,12 +140,13 @@ pub struct DeliveryTimeout { impl DeliveryTimeout { pub fn new(value: u64) -> Self { let first_byte = (value & 0xFF) as u8; // 0xFF: Bit mask to get the first byte - let length = get_length_from_variable_integer_first_byte(first_byte); + let length = get_2msb_length_from_first_byte(first_byte); + let first_two_bits_masked_value = get_2msb_value(value); DeliveryTimeout { parameter_type: VersionSpecificParameterType::DeliveryTimeout, length, - value, + value: first_two_bits_masked_value, } } } @@ -162,12 +161,13 @@ pub struct MaxCacheDuration { impl MaxCacheDuration { pub fn new(value: u64) -> Self { let first_byte = (value & 0xFF) as u8; // 0xFF: Bit mask to get the first byte - let length = get_length_from_variable_integer_first_byte(first_byte); + let length = get_2msb_length_from_first_byte(first_byte); + let first_two_bits_masked_value = get_2msb_value(value); MaxCacheDuration { parameter_type: VersionSpecificParameterType::MaxCacheDuration, length, - value, + value: first_two_bits_masked_value, } } } diff --git a/moqt-core/src/modules/messages/data_streams.rs b/moqt-core/src/modules/messages/data_streams.rs index 209db364..319e08e7 100644 --- a/moqt-core/src/modules/messages/data_streams.rs +++ b/moqt-core/src/modules/messages/data_streams.rs @@ -1,7 +1,8 @@ pub mod datagram; +pub mod datagram_status; +pub mod extension_header; pub mod object_status; pub mod subgroup_stream; -pub mod track_stream; use anyhow::Result; use bytes::BytesMut; @@ -12,3 +13,9 @@ pub trait DataStreams: Send + Sync { Self: Sized; fn packetize(&self, buf: &mut BytesMut); } + +#[derive(Debug, PartialEq, Clone)] +pub enum DatagramObject { + ObjectDatagram(datagram::Object), + ObjectDatagramStatus(datagram_status::Object), +} diff --git a/moqt-core/src/modules/messages/data_streams/datagram.rs b/moqt-core/src/modules/messages/data_streams/datagram.rs index 5b19fec9..f2531cfb 100644 --- a/moqt-core/src/modules/messages/data_streams/datagram.rs +++ b/moqt-core/src/modules/messages/data_streams/datagram.rs @@ -1,66 +1,52 @@ +use super::extension_header::ExtensionHeader; use crate::{ - messages::data_streams::{object_status::ObjectStatus, DataStreams}, - variable_bytes::read_fixed_length_bytes, + messages::data_streams::DataStreams, + variable_bytes::{read_all_variable_bytes, read_bytes}, variable_integer::{read_variable_integer, write_variable_integer}, }; -use anyhow::{bail, Context, Result}; -use bytes::BytesMut; +use anyhow::{Context, Result}; +use bytes::{Buf, BytesMut}; use serde::Serialize; /// Implementation of object message per QUIC Datagram. /// Type of Data Streams: OBJECT_DATAGRAM (0x1) #[derive(Debug, Clone, Serialize, PartialEq)] pub struct Object { - subscribe_id: u64, track_alias: u64, group_id: u64, object_id: u64, publisher_priority: u8, - object_payload_length: u64, - object_status: Option, + extension_headers_length: u64, + extension_headers: Vec, object_payload: Vec, } impl Object { pub fn new( - subscribe_id: u64, track_alias: u64, group_id: u64, object_id: u64, publisher_priority: u8, - object_status: Option, + extension_headers: Vec, object_payload: Vec, ) -> Result { - let object_payload_length = object_payload.len() as u64; - - if object_status.is_some() && object_payload_length != 0 { - bail!("The Object Status field is only sent if the Object Payload Length is zero."); - } - - // Any object with a status code other than zero MUST have an empty payload. - if let Some(status) = object_status { - if status != ObjectStatus::Normal && object_payload_length != 0 { - // TODO: return Termination Error Code - bail!("Any object with a status code other than zero MUST have an empty payload."); - } + // length of total byte of extension headers + let mut extension_headers_length = 0; + for header in &extension_headers { + extension_headers_length += header.byte_length() as u64; } Ok(Object { - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_payload_length, - object_status, + extension_headers_length, + extension_headers, object_payload, }) } - pub fn subscribe_id(&self) -> u64 { - self.subscribe_id - } - pub fn track_alias(&self) -> u64 { self.track_alias } @@ -77,8 +63,8 @@ impl Object { self.publisher_priority } - pub fn object_status(&self) -> Option { - self.object_status + pub fn extension_headers(&self) -> &Vec { + &self.extension_headers } pub fn object_payload(&self) -> Vec { @@ -91,67 +77,51 @@ impl DataStreams for Object { where Self: Sized, { - let subscribe_id = read_variable_integer(read_cur).context("subscribe id")?; let track_alias = read_variable_integer(read_cur).context("track alias")?; let group_id = read_variable_integer(read_cur).context("group id")?; let object_id = read_variable_integer(read_cur).context("object id")?; - let publisher_priority = - read_fixed_length_bytes(read_cur, 1).context("publisher priority")?[0]; - let object_payload_length = - read_variable_integer(read_cur).context("object payload length")?; - - // If the length of the remaining buf is larger than object_payload_length, object_status exists. - // The Object Status field is only sent if the Object Payload Length is zero. - let object_status = if object_payload_length == 0 { - let object_status_u64 = read_variable_integer(read_cur)?; - let object_status = - match ObjectStatus::try_from(object_status_u64 as u8).context("object status") { - Ok(status) => status, - Err(err) => { - // Any other value SHOULD be treated as a Protocol Violation and terminate the session with a Protocol Violation - // TODO: return Termination Error Code - bail!(err); - } - }; - - Some(object_status) - } else { - None - }; + let publisher_priority = read_bytes(read_cur, 1).context("publisher priority")?[0]; - let object_payload = if object_payload_length > 0 { - read_fixed_length_bytes(read_cur, object_payload_length as usize) - .context("object payload")? - } else { - vec![] - }; + let extension_headers_length = + read_variable_integer(read_cur).context("extension headers length")?; + + let mut extension_headers_vec = vec![]; + let extension_headers = + read_bytes(read_cur, extension_headers_length as usize).context("extension headers")?; + let mut extension_headers_cur = std::io::Cursor::new(&extension_headers[..]); + + while extension_headers_cur.has_remaining() { + let extension_header = ExtensionHeader::depacketize(&mut extension_headers_cur) + .context("extension header")?; + extension_headers_vec.push(extension_header); + } + + let object_payload = read_all_variable_bytes(read_cur).context("object payload")?; tracing::trace!("Depacketized Datagram Object message."); Ok(Object { - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_payload_length, - object_status, + extension_headers_length, + extension_headers: extension_headers_vec, object_payload, }) } fn packetize(&self, buf: &mut BytesMut) { - buf.extend(write_variable_integer(self.subscribe_id)); buf.extend(write_variable_integer(self.track_alias)); buf.extend(write_variable_integer(self.group_id)); buf.extend(write_variable_integer(self.object_id)); buf.extend(self.publisher_priority.to_be_bytes()); - buf.extend(write_variable_integer(self.object_payload_length)); - if self.object_status.is_some() { - buf.extend(write_variable_integer( - u8::from(self.object_status.unwrap()) as u64, - )); + + buf.extend(write_variable_integer(self.extension_headers_length)); + for header in &self.extension_headers { + header.packetize(buf); } + buf.extend(&self.object_payload); tracing::trace!("Packetized Datagram Object message."); @@ -161,27 +131,29 @@ impl DataStreams for Object { #[cfg(test)] mod tests { mod success { - use crate::messages::data_streams::{datagram, object_status::ObjectStatus, DataStreams}; + use crate::messages::data_streams::{ + datagram, + extension_header::{ExtensionHeader, ExtensionHeaderValue, Value, ValueWithLength}, + DataStreams, + }; use bytes::BytesMut; use std::io::Cursor; #[test] fn packetize_datagram_object_normal() { - let subscribe_id = 0; let track_alias = 1; let group_id = 2; let object_id = 3; let publisher_priority = 4; - let object_status = None; + let extension_headers = vec![]; let object_payload = vec![0, 1, 2]; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); @@ -190,12 +162,11 @@ mod tests { datagram_object.packetize(&mut buf); let expected_bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 3, // Object Payload Length (i) + 0, // Extension Headers Length (i) 0, 1, 2, // Object Payload (..) ]; @@ -203,22 +174,60 @@ mod tests { } #[test] - fn packetize_datagram_object_normal_and_empty_payload() { - let subscribe_id = 0; + fn depacketize_datagram_object_normal() { + let bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 0, // Extension Headers Length (i) + 0, 1, 2, // Object Payload (..) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_datagram_object = + datagram::Object::depacketize(&mut read_cur).unwrap(); + + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let extension_headers = vec![]; + let object_payload = vec![0, 1, 2]; + + let expected_datagram_object = datagram::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_payload, + ) + .unwrap(); + + assert_eq!(depacketized_datagram_object, expected_datagram_object); + } + + #[test] + fn packetize_datagram_stream_object_with_even_type_extension_header() { let track_alias = 1; let group_id = 2; let object_id = 3; let publisher_priority = 4; - let object_status = Some(ObjectStatus::Normal); - let object_payload = vec![]; + let header_type = 4; + let value = 1; + let header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_payload = vec![0, 1, 2]; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); @@ -227,35 +236,38 @@ mod tests { datagram_object.packetize(&mut buf); let expected_bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 0, // Object Payload Length (i) - 0, // Object Status (i) + 2, // Extension Headers Length (i) + 4, // Header Type (i) + 1, // Header Value (i) + 0, 1, 2, // Object Payload (..) ]; assert_eq!(buf.as_ref(), expected_bytes_array); } #[test] - fn packetize_datagram_object_not_normal() { - let subscribe_id = 0; + fn packetize_datagram_stream_object_with_odd_type_extension_header() { let track_alias = 1; let group_id = 2; let object_id = 3; let publisher_priority = 4; - let object_status = Some(ObjectStatus::EndOfGroup); - let object_payload = vec![]; + let header_type = 1; + let value = vec![1, 2, 3]; + let header_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_payload = vec![0, 1, 2]; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); @@ -264,27 +276,83 @@ mod tests { datagram_object.packetize(&mut buf); let expected_bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 0, // Object Payload Length (i) - 3, // Object Status (i) + 5, // Extension Headers Length (i) + 1, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + 0, 1, 2, // Object Payload (..) ]; assert_eq!(buf.as_ref(), expected_bytes_array); } #[test] - fn depacketize_datagram_object_normal() { + fn packetize_datagram_stream_object_with_mixed_type_extension_headers() { + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let even_header_type = 12; + let even_value = 1; + let even_header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(even_value)); + let odd_header_type = 9; + let odd_value = vec![1, 2, 3]; + let odd_header_value = + ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(odd_value)); + + let extension_headers = vec![ + ExtensionHeader::new(odd_header_type, odd_header_value).unwrap(), + ExtensionHeader::new(even_header_type, even_header_value).unwrap(), + ]; + let object_payload = vec![0, 1, 2]; + + let datagram_object = datagram::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_payload, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + datagram_object.packetize(&mut buf); + + let expected_bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 7, // Extension Headers Length (i) + //{ + 9, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + // }{ + 12, // Header Type (i) + 1, // Header Value (i) + // } + 0, 1, 2, // Object Payload (..) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn depacketize_datagram_stream_object_with_even_type_extension_header() { let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 3, // Object Payload Length (i) + 2, // Extension Headers Length (i) + 4, // Header Type (i) + 1, // Header Value (i) 0, 1, 2, // Object Payload (..) ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); @@ -293,21 +361,23 @@ mod tests { let depacketized_datagram_object = datagram::Object::depacketize(&mut read_cur).unwrap(); - let subscribe_id = 0; let track_alias = 1; let group_id = 2; let object_id = 3; let publisher_priority = 4; - let object_status = None; + let header_type = 4; + let value = 1; + let header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; let object_payload = vec![0, 1, 2]; let expected_datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); @@ -316,15 +386,17 @@ mod tests { } #[test] - fn depacketize_datagram_object_normal_and_empty_payload() { + fn depacketize_datagram_stream_object_with_odd_type_extension_header() { let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 0, // Object Payload Length (i) - 0, // Object Status (i) + 5, // Extension Headers Length (i) + 1, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + 0, 1, 2, // Object Payload (..) ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); @@ -332,21 +404,23 @@ mod tests { let depacketized_datagram_object = datagram::Object::depacketize(&mut read_cur).unwrap(); - let subscribe_id = 0; let track_alias = 1; let group_id = 2; let object_id = 3; let publisher_priority = 4; - let object_status = Some(ObjectStatus::Normal); - let object_payload = vec![]; + let header_type = 1; + let value = vec![1, 2, 3]; + let header_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_payload = vec![0, 1, 2]; let expected_datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); @@ -355,15 +429,22 @@ mod tests { } #[test] - fn depacketize_datagram_object_not_normal() { + fn depacketize_datagram_stream_object_with_mixed_type_extension_headers() { let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 0, // Object Payload Length (i) - 1, // Object Status (i) + 7, // Extension Headers Length (i) + //{ + 9, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + // }{ + 12, // Header Type (i) + 1, // Header Value (i) + // } + 0, 1, 2, // Object Payload (..) ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); @@ -371,21 +452,30 @@ mod tests { let depacketized_datagram_object = datagram::Object::depacketize(&mut read_cur).unwrap(); - let subscribe_id = 0; let track_alias = 1; let group_id = 2; let object_id = 3; let publisher_priority = 4; - let object_status = Some(ObjectStatus::DoesNotExist); - let object_payload = vec![]; + let even_header_type = 12; + let even_value = 1; + let even_header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(even_value)); + let odd_header_type = 9; + let odd_value = vec![1, 2, 3]; + let odd_header_value = + ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(odd_value)); + + let extension_headers = vec![ + ExtensionHeader::new(odd_header_type, odd_header_value).unwrap(), + ExtensionHeader::new(even_header_type, even_header_value).unwrap(), + ]; + let object_payload = vec![0, 1, 2]; let expected_datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); @@ -398,41 +488,16 @@ mod tests { use bytes::BytesMut; use std::io::Cursor; - use crate::messages::data_streams::{datagram, object_status::ObjectStatus, DataStreams}; - - #[test] - fn packetize_datagram_object_not_normal_and_not_empty_payload() { - let subscribe_id = 0; - let track_alias = 1; - let group_id = 2; - let object_id = 3; - let publisher_priority = 4; - let object_status = Some(ObjectStatus::EndOfTrackAndGroup); - let object_payload = vec![0, 1, 2]; - - let datagram_object = datagram::Object::new( - subscribe_id, - track_alias, - group_id, - object_id, - publisher_priority, - object_status, - object_payload, - ); - - assert!(datagram_object.is_err()); - } + use crate::messages::data_streams::{datagram, DataStreams}; #[test] - fn depacketize_datagram_object_wrong_object_status() { + fn depacketize_datagram_object_with_empty_payload() { let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) - 0, // Object Payload Length (i) - 2, // Object Status (i) + 0, // Extension Headers Length (i) ]; let mut buf = BytesMut::with_capacity(bytes_array.len()); buf.extend_from_slice(&bytes_array); diff --git a/moqt-core/src/modules/messages/data_streams/datagram_status.rs b/moqt-core/src/modules/messages/data_streams/datagram_status.rs new file mode 100644 index 00000000..6ec6deeb --- /dev/null +++ b/moqt-core/src/modules/messages/data_streams/datagram_status.rs @@ -0,0 +1,521 @@ +use super::{extension_header::ExtensionHeader, object_status::ObjectStatus}; +use crate::{ + messages::data_streams::DataStreams, + variable_bytes::read_bytes, + variable_integer::{read_variable_integer, write_variable_integer}, +}; +use anyhow::{bail, Context, Result}; +use bytes::{Buf, BytesMut}; +use serde::Serialize; + +/// Implementation of object message per QUIC Datagram. +/// Type of Data Streams: OBJECT_DATAGRAM_STATUS (0x2) +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct Object { + track_alias: u64, + group_id: u64, + object_id: u64, + publisher_priority: u8, + extension_headers_length: u64, + extension_headers: Vec, + object_status: ObjectStatus, +} + +impl Object { + pub fn new( + track_alias: u64, + group_id: u64, + object_id: u64, + publisher_priority: u8, + extension_headers: Vec, + object_status: ObjectStatus, + ) -> Result { + // length of total byte of extension headers + let mut extension_headers_length = 0; + for header in &extension_headers { + extension_headers_length += header.byte_length() as u64; + } + + Ok(Object { + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers_length, + extension_headers, + object_status, + }) + } + + pub fn track_alias(&self) -> u64 { + self.track_alias + } + + pub fn group_id(&self) -> u64 { + self.group_id + } + + pub fn object_id(&self) -> u64 { + self.object_id + } + + pub fn publisher_priority(&self) -> u8 { + self.publisher_priority + } + + pub fn extension_headers(&self) -> &Vec { + &self.extension_headers + } + + pub fn object_status(&self) -> ObjectStatus { + self.object_status + } +} + +impl DataStreams for Object { + fn depacketize(read_cur: &mut std::io::Cursor<&[u8]>) -> Result + where + Self: Sized, + { + let track_alias = read_variable_integer(read_cur).context("track alias")?; + let group_id = read_variable_integer(read_cur).context("group id")?; + let object_id = read_variable_integer(read_cur).context("object id")?; + let publisher_priority = read_bytes(read_cur, 1).context("publisher priority")?[0]; + + let extension_headers_length = + read_variable_integer(read_cur).context("extension headers length")?; + + let mut extension_headers_vec = vec![]; + let extension_headers = + read_bytes(read_cur, extension_headers_length as usize).context("extension headers")?; + let mut extension_headers_cur = std::io::Cursor::new(&extension_headers[..]); + + while extension_headers_cur.has_remaining() { + let extension_header = ExtensionHeader::depacketize(&mut extension_headers_cur) + .context("extension header")?; + extension_headers_vec.push(extension_header); + } + + let object_status_u64 = read_variable_integer(read_cur)?; + let object_status = + match ObjectStatus::try_from(object_status_u64 as u8).context("object status") { + Ok(status) => status, + Err(err) => { + // Any other value SHOULD be treated as a Protocol Violation and terminate the session with a Protocol Violation + // TODO: return Termination Error Code + bail!(err); + } + }; + + tracing::trace!("Depacketized Datagram Object message."); + + Ok(Object { + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers_length, + extension_headers: extension_headers_vec, + object_status, + }) + } + + fn packetize(&self, buf: &mut BytesMut) { + buf.extend(write_variable_integer(self.track_alias)); + buf.extend(write_variable_integer(self.group_id)); + buf.extend(write_variable_integer(self.object_id)); + buf.extend(self.publisher_priority.to_be_bytes()); + + buf.extend(write_variable_integer(self.extension_headers_length)); + for header in &self.extension_headers { + header.packetize(buf); + } + + buf.extend(write_variable_integer(u8::from(self.object_status) as u64)); + + tracing::trace!("Packetized Datagram Object message."); + } +} + +#[cfg(test)] +mod tests { + mod success { + use crate::messages::data_streams::{ + datagram_status, + extension_header::{ExtensionHeader, ExtensionHeaderValue, Value, ValueWithLength}, + object_status::ObjectStatus, + DataStreams, + }; + use bytes::BytesMut; + use std::io::Cursor; + + #[test] + fn packetize_datagram_status() { + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let extension_headers = vec![]; + let object_status = ObjectStatus::Normal; + + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + datagram_object.packetize(&mut buf); + + let expected_bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 0, // Extension Headers Length (i) + 0, // Object Status (i) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn depacketize_datagram_status() { + let bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 0, // Extension Headers Length (i) + 1, // Object Status (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_datagram_object = + datagram_status::Object::depacketize(&mut read_cur).unwrap(); + + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let extension_headers = vec![]; + let object_status = ObjectStatus::DoesNotExist; + + let expected_datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + assert_eq!(depacketized_datagram_object, expected_datagram_object); + } + + #[test] + fn packetize_datagram_status_with_even_type_extension_header() { + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let header_type = 4; + let value = 1; + let header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = ObjectStatus::EndOfGroup; + + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + datagram_object.packetize(&mut buf); + + let expected_bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 2, // Extension Headers Length (i) + 4, // Header Type (i) + 1, // Header Value (i) + 3, // Object Status (i) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn packetize_datagram_status_with_odd_type_extension_header() { + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let header_type = 1; + let value = vec![1, 2, 3]; + let header_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = ObjectStatus::EndOfTrackAndGroup; + + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + datagram_object.packetize(&mut buf); + + let expected_bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 5, // Extension Headers Length (i) + 1, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + 4, // Object Status (i) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn packetize_datagram_status_with_mixed_type_extension_headers() { + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let even_header_type = 12; + let even_value = 1; + let even_header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(even_value)); + let odd_header_type = 9; + let odd_value = vec![1, 2, 3]; + let odd_header_value = + ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(odd_value)); + + let extension_headers = vec![ + ExtensionHeader::new(odd_header_type, odd_header_value).unwrap(), + ExtensionHeader::new(even_header_type, even_header_value).unwrap(), + ]; + let object_status = ObjectStatus::EndOfTrack; + + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + datagram_object.packetize(&mut buf); + + let expected_bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 7, // Extension Headers Length (i) + //{ + 9, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + // }{ + 12, // Header Type (i) + 1, // Header Value (i) + // } + 5, // Object Status (i) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn depacketize_datagram_status_with_even_type_extension_header() { + let bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 2, // Extension Headers Length (i) + 4, // Header Type (i) + 1, // Header Value (i) + 0, // Object Status (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_datagram_object = + datagram_status::Object::depacketize(&mut read_cur).unwrap(); + + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let header_type = 4; + let value = 1; + let header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = ObjectStatus::Normal; + + let expected_datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + assert_eq!(depacketized_datagram_object, expected_datagram_object); + } + + #[test] + fn depacketize_datagram_status_with_odd_type_extension_header() { + let bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 5, // Extension Headers Length (i) + 1, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + 0, // Object Status (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_datagram_object = + datagram_status::Object::depacketize(&mut read_cur).unwrap(); + + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let header_type = 1; + let value = vec![1, 2, 3]; + let header_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = ObjectStatus::Normal; + + let expected_datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + assert_eq!(depacketized_datagram_object, expected_datagram_object); + } + + #[test] + fn depacketize_datagram_status_with_mixed_type_extension_headers() { + let bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 7, // Extension Headers Length (i) + //{ + 9, // Header Type (i) + 3, // Header Value Length (i) + 1, 2, 3, // Header Value (..) + // }{ + 12, // Header Type (i) + 1, // Header Value (i) + // } + 0, // Object Status (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_datagram_object = + datagram_status::Object::depacketize(&mut read_cur).unwrap(); + + let track_alias = 1; + let group_id = 2; + let object_id = 3; + let publisher_priority = 4; + let even_header_type = 12; + let even_value = 1; + let even_header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(even_value)); + let odd_header_type = 9; + let odd_value = vec![1, 2, 3]; + let odd_header_value = + ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(odd_value)); + + let extension_headers = vec![ + ExtensionHeader::new(odd_header_type, odd_header_value).unwrap(), + ExtensionHeader::new(even_header_type, even_header_value).unwrap(), + ]; + let object_status = ObjectStatus::Normal; + + let expected_datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + + assert_eq!(depacketized_datagram_object, expected_datagram_object); + } + } + + mod failure { + use bytes::BytesMut; + use std::io::Cursor; + + use crate::messages::data_streams::{datagram_status, DataStreams}; + + #[test] + fn depacketize_datagram_status_with_unknown_status() { + let bytes_array = [ + 1, // Track Alias (i) + 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 0, // Extension Headers Length (i) + 20, // Object Status (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_datagram_object = datagram_status::Object::depacketize(&mut read_cur); + + assert!(depacketized_datagram_object.is_err()); + } + } +} diff --git a/moqt-core/src/modules/messages/data_streams/extension_header.rs b/moqt-core/src/modules/messages/data_streams/extension_header.rs new file mode 100644 index 00000000..c082c311 --- /dev/null +++ b/moqt-core/src/modules/messages/data_streams/extension_header.rs @@ -0,0 +1,172 @@ +use crate::{ + messages::data_streams::DataStreams, + variable_bytes::read_bytes, + variable_integer::{read_variable_integer, write_variable_integer}, +}; +use anyhow::{bail, Context, Result}; +use bytes::BytesMut; +use serde::Serialize; + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct ExtensionHeader { + header_type: u64, + value: ExtensionHeaderValue, +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub enum ExtensionHeaderValue { + EvenTypeValue(Value), + OddTypeValue(ValueWithLength), +} + +impl ExtensionHeader { + pub fn new(header_type: u64, value: ExtensionHeaderValue) -> Result { + if header_type % 2 == 0 && matches!(value, ExtensionHeaderValue::OddTypeValue(_)) { + bail!("Mismatched value type: expected even, but got odd"); + } + + if header_type % 2 != 0 && matches!(value, ExtensionHeaderValue::EvenTypeValue(_)) { + bail!("Mismatched value type: expected odd, but got even"); + } + + Ok(ExtensionHeader { header_type, value }) + } + + pub fn byte_length(&self) -> usize { + let mut len = write_variable_integer(self.header_type).len(); + match &self.value { + ExtensionHeaderValue::EvenTypeValue(value) => len += value.byte_length(), + ExtensionHeaderValue::OddTypeValue(value_with_length) => { + len += value_with_length.byte_length() + } + } + len + } +} + +impl DataStreams for ExtensionHeader { + fn depacketize(read_cur: &mut std::io::Cursor<&[u8]>) -> Result + where + Self: Sized, + { + let header_type = read_variable_integer(read_cur).context("header type")?; + if header_type % 2 == 0 { + let value = ExtensionHeaderValue::EvenTypeValue(Value::depacketize(read_cur)?); + Ok(ExtensionHeader { header_type, value }) + } else { + let value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::depacketize(read_cur)?); + Ok(ExtensionHeader { header_type, value }) + } + } + + fn packetize(&self, buf: &mut BytesMut) { + buf.extend(write_variable_integer(self.header_type)); + match &self.value { + ExtensionHeaderValue::EvenTypeValue(value) => value.packetize(buf), + ExtensionHeaderValue::OddTypeValue(value_with_length) => { + value_with_length.packetize(buf) + } + } + } +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct Value { + header_value: u64, +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct ValueWithLength { + header_length: u64, + header_value: Vec, +} + +impl Value { + pub fn new(header_value: u64) -> Self { + Value { header_value } + } + + pub fn byte_length(&self) -> usize { + write_variable_integer(self.header_value).len() + } +} + +impl DataStreams for Value { + fn depacketize(read_cur: &mut std::io::Cursor<&[u8]>) -> Result + where + Self: Sized, + { + let header_value = read_variable_integer(read_cur).context("header length")?; + + Ok(Value { header_value }) + } + + fn packetize(&self, buf: &mut BytesMut) { + buf.extend(write_variable_integer(self.header_value)); + } +} + +impl ValueWithLength { + pub fn new(header_value: Vec) -> Self { + ValueWithLength { + header_length: header_value.len() as u64, + header_value, + } + } + + pub fn byte_length(&self) -> usize { + let mut len = write_variable_integer(self.header_length).len(); + len += self.header_value.len(); + len + } +} + +impl DataStreams for ValueWithLength { + fn depacketize(read_cur: &mut std::io::Cursor<&[u8]>) -> Result + where + Self: Sized, + { + let header_length = read_variable_integer(read_cur).context("header length")?; + let header_value = if header_length > 0 { + read_bytes(read_cur, header_length as usize).context("header value")? + } else { + vec![] + }; + + Ok(ValueWithLength { + header_length, + header_value, + }) + } + + fn packetize(&self, buf: &mut BytesMut) { + buf.extend(write_variable_integer(self.header_length)); + buf.extend(&self.header_value); + } +} + +#[cfg(test)] +mod failure { + use super::ValueWithLength; + use crate::messages::data_streams::extension_header::{ + ExtensionHeader, ExtensionHeaderValue, Value, + }; + + #[test] + fn new_odd_value_with_even_type() { + let even_header_type = 0; + let odd_type_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(vec![0])); + let extension_header = ExtensionHeader::new(even_header_type, odd_type_value); + + assert!(extension_header.is_err()); + } + + #[test] + fn new_even_value_with_odd_type() { + let odd_header_type = 1; + let even_type_value = ExtensionHeaderValue::EvenTypeValue(Value::new(0)); + let extension_header = ExtensionHeader::new(odd_header_type, even_type_value); + + assert!(extension_header.is_err()); + } +} diff --git a/moqt-core/src/modules/messages/data_streams/object_status.rs b/moqt-core/src/modules/messages/data_streams/object_status.rs index cf2c9ae6..e8625af3 100644 --- a/moqt-core/src/modules/messages/data_streams/object_status.rs +++ b/moqt-core/src/modules/messages/data_streams/object_status.rs @@ -8,5 +8,5 @@ pub enum ObjectStatus { DoesNotExist = 0x1, EndOfGroup = 0x3, EndOfTrackAndGroup = 0x4, - EndOfSubgroup = 0x5, + EndOfTrack = 0x5, } diff --git a/moqt-core/src/modules/messages/data_streams/subgroup_stream.rs b/moqt-core/src/modules/messages/data_streams/subgroup_stream.rs index bf421117..d71cb059 100644 --- a/moqt-core/src/modules/messages/data_streams/subgroup_stream.rs +++ b/moqt-core/src/modules/messages/data_streams/subgroup_stream.rs @@ -1,19 +1,18 @@ -use super::object_status::ObjectStatus; +use super::{extension_header::ExtensionHeader, object_status::ObjectStatus}; use crate::{ messages::data_streams::DataStreams, - variable_bytes::read_fixed_length_bytes, + variable_bytes::read_bytes, variable_integer::{read_variable_integer, write_variable_integer}, }; use anyhow::{bail, Context, Result}; -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use serde::Serialize; /// Implementation of header message on QUIC Stream per Subgroup. /// Object messages are sent following this message. -/// Type of Data Streams:STREAM_HEADER_SUBGROUP (0x4) +/// Type of Data Streams:SUBGROUP_HEADER (0x4) #[derive(Debug, Clone, Serialize, PartialEq, Default)] pub struct Header { - subscribe_id: u64, track_alias: u64, group_id: u64, subgroup_id: u64, @@ -22,14 +21,12 @@ pub struct Header { impl Header { pub fn new( - subscribe_id: u64, track_alias: u64, group_id: u64, subgroup_id: u64, publisher_priority: u8, ) -> Result { Ok(Header { - subscribe_id, track_alias, group_id, subgroup_id, @@ -37,10 +34,6 @@ impl Header { }) } - pub fn subscribe_id(&self) -> u64 { - self.subscribe_id - } - pub fn track_alias(&self) -> u64 { self.track_alias } @@ -63,17 +56,14 @@ impl DataStreams for Header { where Self: Sized, { - let subscribe_id = read_variable_integer(read_cur).context("subscribe id")?; let track_alias = read_variable_integer(read_cur).context("track alias")?; let group_id = read_variable_integer(read_cur).context("group id")?; let subgroup_id = read_variable_integer(read_cur).context("subgroup id")?; - let publisher_priority = - read_fixed_length_bytes(read_cur, 1).context("publisher priority")?[0]; + let publisher_priority = read_bytes(read_cur, 1).context("publisher priority")?[0]; tracing::trace!("Depacketized Subgroup Stream Header message."); Ok(Header { - subscribe_id, track_alias, group_id, subgroup_id, @@ -82,7 +72,6 @@ impl DataStreams for Header { } fn packetize(&self, buf: &mut BytesMut) { - buf.extend(write_variable_integer(self.subscribe_id)); buf.extend(write_variable_integer(self.track_alias)); buf.extend(write_variable_integer(self.group_id)); buf.extend(write_variable_integer(self.subgroup_id)); @@ -97,6 +86,8 @@ impl DataStreams for Header { #[derive(Debug, Clone, Serialize, PartialEq)] pub struct Object { object_id: u64, + extension_headers_length: u64, + extension_headers: Vec, object_payload_length: u64, object_status: Option, object_payload: Vec, @@ -105,6 +96,7 @@ pub struct Object { impl Object { pub fn new( object_id: u64, + extension_headers: Vec, object_status: Option, object_payload: Vec, ) -> Result { @@ -122,8 +114,16 @@ impl Object { } } + // length of total byte of extension headers + let mut extension_headers_length = 0; + for header in &extension_headers { + extension_headers_length += header.byte_length() as u64; + } + Ok(Object { object_id, + extension_headers_length, + extension_headers, object_payload_length, object_status, object_payload, @@ -145,6 +145,20 @@ impl DataStreams for Object { Self: Sized, { let object_id = read_variable_integer(read_cur).context("object id")?; + let extension_headers_length = + read_variable_integer(read_cur).context("extension headers length")?; + + let mut extension_headers_vec = vec![]; + let extension_headers = + read_bytes(read_cur, extension_headers_length as usize).context("extension headers")?; + let mut extension_headers_cur = std::io::Cursor::new(&extension_headers[..]); + + while extension_headers_cur.has_remaining() { + let extension_header = ExtensionHeader::depacketize(&mut extension_headers_cur) + .context("extension header")?; + extension_headers_vec.push(extension_header); + } + let object_payload_length = read_variable_integer(read_cur).context("object payload length")?; @@ -167,8 +181,7 @@ impl DataStreams for Object { }; let object_payload = if object_payload_length > 0 { - read_fixed_length_bytes(read_cur, object_payload_length as usize) - .context("object payload")? + read_bytes(read_cur, object_payload_length as usize).context("object payload")? } else { vec![] }; @@ -177,6 +190,8 @@ impl DataStreams for Object { Ok(Object { object_id, + extension_headers_length, + extension_headers: extension_headers_vec, object_payload_length, object_status, object_payload, @@ -185,6 +200,12 @@ impl DataStreams for Object { fn packetize(&self, buf: &mut BytesMut) { buf.extend(write_variable_integer(self.object_id)); + + buf.extend(write_variable_integer(self.extension_headers_length)); + for header in &self.extension_headers { + header.packetize(buf); + } + buf.extend(write_variable_integer(self.object_payload_length)); if self.object_status.is_some() { buf.extend(write_variable_integer( @@ -200,24 +221,23 @@ impl DataStreams for Object { #[cfg(test)] mod tests { mod success { - use bytes::BytesMut; - use std::io::Cursor; - + use crate::messages::data_streams::extension_header::{ + ExtensionHeader, ExtensionHeaderValue, Value, ValueWithLength, + }; use crate::messages::data_streams::{ - object_status::ObjectStatus, - {subgroup_stream, DataStreams}, + object_status::ObjectStatus, subgroup_stream, DataStreams, }; + use bytes::BytesMut; + use std::io::Cursor; #[test] fn packetize_subgroup_stream_header() { - let subscribe_id = 0; let track_alias = 1; let group_id = 2; let subgroup_id = 3; let publisher_priority = 4; let subgroup_stream_header = subgroup_stream::Header::new( - subscribe_id, track_alias, group_id, subgroup_id, @@ -229,7 +249,6 @@ mod tests { subgroup_stream_header.packetize(&mut buf); let expected_bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Subgroup ID (i) @@ -242,7 +261,6 @@ mod tests { #[test] fn depacketize_subgroup_stream_header() { let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Subgroup ID (i) @@ -254,14 +272,12 @@ mod tests { let depacketized_subgroup_stream_header = subgroup_stream::Header::depacketize(&mut read_cur).unwrap(); - let subscribe_id = 0; let track_alias = 1; let group_id = 2; let subgroup_id = 3; let publisher_priority = 4; let expected_subgroup_stream_header = subgroup_stream::Header::new( - subscribe_id, track_alias, group_id, subgroup_id, @@ -278,18 +294,25 @@ mod tests { #[test] fn packetize_subgroup_stream_object_normal() { let object_id = 0; + let extension_headers = vec![]; let object_status = None; let object_payload = vec![0, 1, 2]; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); let mut buf = BytesMut::new(); subgroup_stream_object.packetize(&mut buf); let expected_bytes_array = [ 0, // Object ID (i) - 3, // Object Payload Length (i + 0, // Extension Headers Length (i) + 3, // Object Payload Length (i) 0, 1, 2, // Object Payload (..) ]; @@ -299,17 +322,24 @@ mod tests { #[test] fn packetize_subgroup_stream_object_normal_and_empty_payload() { let object_id = 0; + let extension_headers = vec![]; let object_status = Some(ObjectStatus::Normal); let object_payload = vec![]; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); let mut buf = BytesMut::new(); subgroup_stream_object.packetize(&mut buf); let expected_bytes_array = [ 0, // Object ID (i) + 0, // Extension Headers Length (i) 0, // Object Payload Length (i) 0, // Object Status (i) ]; @@ -320,17 +350,24 @@ mod tests { #[test] fn packetize_subgroup_stream_object_not_normal() { let object_id = 0; + let extension_headers = vec![]; let object_status = Some(ObjectStatus::EndOfGroup); let object_payload = vec![]; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); let mut buf = BytesMut::new(); subgroup_stream_object.packetize(&mut buf); let expected_bytes_array = [ 0, // Object ID (i) + 0, // Extension Headers Length (i) 0, // Object Payload Length (i) 3, // Object Status (i) ]; @@ -342,6 +379,7 @@ mod tests { fn depacketize_subgroup_stream_object_normal() { let bytes_array = [ 0, // Object ID (i) + 0, // Extension Headers Length (i) 0, // Object Payload Length (i) 0, // Object Status (i) ]; @@ -352,11 +390,17 @@ mod tests { subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); let object_id = 0; + let extension_headers = vec![]; let object_status = Some(ObjectStatus::Normal); let object_payload = vec![]; - let expected_subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let expected_subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); assert_eq!( depacketized_subgroup_stream_object, @@ -368,6 +412,7 @@ mod tests { fn depacketize_subgroup_stream_object_normal_and_empty_payload() { let bytes_array = [ 0, // Object ID (i) + 0, // Extension Headers Length (i) 0, // Object Payload Length (i) 0, // Object Status (i) ]; @@ -378,11 +423,17 @@ mod tests { subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); let object_id = 0; + let extension_headers = vec![]; let object_status = Some(ObjectStatus::Normal); let object_payload = vec![]; - let expected_subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let expected_subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); assert_eq!( depacketized_subgroup_stream_object, @@ -394,6 +445,7 @@ mod tests { fn depacketize_subgroup_stream_object_not_normal() { let bytes_array = [ 0, // Object ID (i) + 0, // Extension Headers Length (i) 0, // Object Payload Length (i) 1, // Object Status (i) ]; @@ -404,17 +456,289 @@ mod tests { subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); let object_id = 0; + let extension_headers = vec![]; let object_status = Some(ObjectStatus::DoesNotExist); let object_payload = vec![]; - let expected_subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let expected_subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); + + assert_eq!( + depacketized_subgroup_stream_object, + expected_subgroup_stream_object + ); + } + + #[test] + fn packetize_subgroup_stream_object_with_even_type_extension_header() { + let object_id = 0; + let header_type = 0; + let value = 1; + let header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = None; + let object_payload = vec![1, 2, 3]; + + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + subgroup_stream_object.packetize(&mut buf); + + let expected_bytes_array = [ + 0, // Object ID (i) + 2, // Extension Headers Length (i) + 0, // Header Type (i) + 1, // Header Value (i) + 3, // Object Payload Length (i) + 1, 2, 3, // Object Payload (..) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn packetize_subgroup_stream_object_with_odd_type_extension_header() { + let object_id = 0; + let header_type = 1; + let value = vec![116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, 54]; + let header_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = Some(ObjectStatus::Normal); + let object_payload = vec![]; + + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + subgroup_stream_object.packetize(&mut buf); + + let expected_bytes_array = [ + 0, // Object ID (i) + 16, // Extension Headers Length (i) + 1, // Header Type (i) + 14, // Header Length (i) + 116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, + 54, // Header Value (..) + 0, // Object Payload Length (i) + 0, // Object Status (i) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn packetize_subgroup_stream_object_with_mixed_type_extension_headers() { + let object_id = 0; + + let even_header_type = 4; + let even_value = 3; + let even_header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(even_value)); + + let odd_header_type = 5; + let odd_value = vec![116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, 54]; + let odd_header_value = + ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(odd_value)); + + let extension_headers = vec![ + ExtensionHeader::new(even_header_type, even_header_value).unwrap(), + ExtensionHeader::new(odd_header_type, odd_header_value).unwrap(), + ]; + let object_status = Some(ObjectStatus::Normal); + let object_payload = vec![]; + + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); + + let mut buf = BytesMut::new(); + subgroup_stream_object.packetize(&mut buf); + + let expected_bytes_array = [ + 0, // Object ID (i) + 18, // Extension Headers Length (i) + // { + 4, // Header Type (i) + 3, // Header Value (i) + // }{ + 5, // Header Type (i) + 14, // Header Length (i) + 116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, + 54, // Header Value (..) + // } + 0, // Object Payload Length (i) + 0, // Object Status (i) + ]; + + assert_eq!(buf.as_ref(), expected_bytes_array); + } + + #[test] + fn depacketize_subgroup_stream_object_with_even_type_extension_header() { + let bytes_array = [ + 0, // Object ID (i) + 2, // Extension Headers Length (i) + 0, // Header Type (i) + 1, // Header Value (i) + 0, // Object Payload Length (i) + 0, // Object Status (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_subgroup_stream_object = + subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); + + let object_id = 0; + let header_type = 0; + let value = 1; + let header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = Some(ObjectStatus::Normal); + let object_payload = vec![]; + + let expected_subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); + + assert_eq!( + depacketized_subgroup_stream_object, + expected_subgroup_stream_object + ); + } + + #[test] + fn depacketize_subgroup_stream_object_with_odd_type_extension_header() { + let bytes_array = [ + 0, // Object ID (i) + 16, // Extension Headers Length (i) + 1, // Header Type (i) + 14, // Header Length (i) + 116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, + 54, // Header Value (..) + 3, // Object Payload Length (i) + 1, 2, 3, // Object Payload (..) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_subgroup_stream_object = + subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); + + let object_id = 0; + let header_type = 1; + let value = vec![116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, 54]; + let header_value = ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(value)); + + let extension_headers = vec![ExtensionHeader::new(header_type, header_value).unwrap()]; + let object_status = None; + let object_payload = vec![1, 2, 3]; + + let expected_subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); assert_eq!( depacketized_subgroup_stream_object, expected_subgroup_stream_object ); } + + #[test] + fn depacketize_subgroup_stream_object_with_mixed_type_extension_headers() { + let bytes_array = [ + 0, // Object ID (i) + 18, // Extension Headers Length (i) + // { + 4, // Header Type (i) + 3, // Header Value (i) + // }{ + 5, // Header Type (i) + 14, // Header Length (i) + 116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, + 54, // Header Value (..) + // } + 0, // Object Payload Length (i) + 0, // Object Status (i) + ]; + + let object_id = 0; + + let even_header_type = 4; + let even_value = 3; + let even_header_value = ExtensionHeaderValue::EvenTypeValue(Value::new(even_value)); + + let odd_header_type = 5; + let odd_value = vec![116, 114, 97, 99, 101, 73, 68, 58, 49, 50, 51, 52, 53, 54]; + let odd_header_value = + ExtensionHeaderValue::OddTypeValue(ValueWithLength::new(odd_value)); + + let extension_headers = vec![ + ExtensionHeader::new(even_header_type, even_header_value).unwrap(), + ExtensionHeader::new(odd_header_type, odd_header_value).unwrap(), + ]; + let object_status = Some(ObjectStatus::Normal); + let object_payload = vec![]; + + let expected_subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ) + .unwrap(); + + println!( + "expected_subgroup_stream_object: {:?}", + expected_subgroup_stream_object + ); + + let mut buf: BytesMut = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf[..]); + let depacketized_subgroup_stream_object = + subgroup_stream::Object::depacketize(&mut read_cur); + + println!( + "depacketized_subgroup_stream_object: {:?}", + depacketized_subgroup_stream_object + ); + + assert_eq!( + depacketized_subgroup_stream_object.unwrap(), + expected_subgroup_stream_object + ); + } } mod failure { @@ -428,11 +752,16 @@ mod tests { #[test] fn packetize_subgroup_stream_object_not_normal_and_not_empty_payload() { let object_id = 0; + let extension_headers = vec![]; let object_status = Some(ObjectStatus::EndOfTrackAndGroup); let object_payload = vec![0, 1, 2]; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, + ); assert!(subgroup_stream_object.is_err()); } @@ -441,6 +770,7 @@ mod tests { fn depacketize_subgroup_stream_object_wrong_object_status() { let bytes_array = [ 0, // Object ID (i) + 0, // Extension Headers Length (i) 0, // Object Payload Length (i) 2, // Object Status (i) ]; diff --git a/moqt-core/src/modules/messages/data_streams/track_stream.rs b/moqt-core/src/modules/messages/data_streams/track_stream.rs deleted file mode 100644 index cddbac9e..00000000 --- a/moqt-core/src/modules/messages/data_streams/track_stream.rs +++ /dev/null @@ -1,439 +0,0 @@ -use crate::{ - messages::data_streams::DataStreams, - variable_bytes::read_fixed_length_bytes, - variable_integer::{read_variable_integer, write_variable_integer}, -}; -use anyhow::{bail, Context, Result}; -use bytes::BytesMut; -use serde::Serialize; - -use super::object_status::ObjectStatus; - -/// Implementation of header message on QUIC Stream per Track. -/// TrackObject messages are sent following this message. -/// Type of Data Streams: STREAM_HEADER_TRACK (0x2) -#[derive(Debug, Clone, Serialize, PartialEq, Default)] -pub struct Header { - subscribe_id: u64, - track_alias: u64, - publisher_priority: u8, -} - -impl Header { - pub fn new(subscribe_id: u64, track_alias: u64, publisher_priority: u8) -> Result { - Ok(Header { - subscribe_id, - track_alias, - publisher_priority, - }) - } - - pub fn subscribe_id(&self) -> u64 { - self.subscribe_id - } - - pub fn track_alias(&self) -> u64 { - self.track_alias - } - - pub fn publisher_priority(&self) -> u8 { - self.publisher_priority - } -} - -impl DataStreams for Header { - fn depacketize(read_cur: &mut std::io::Cursor<&[u8]>) -> Result - where - Self: Sized, - { - let subscribe_id = read_variable_integer(read_cur).context("subscribe id")?; - let track_alias = read_variable_integer(read_cur).context("track alias")?; - let publisher_priority = - read_fixed_length_bytes(read_cur, 1).context("publisher priority")?[0]; - - tracing::trace!("Depacketized Track Stream Header message."); - - Ok(Header { - subscribe_id, - track_alias, - publisher_priority, - }) - } - - fn packetize(&self, buf: &mut BytesMut) { - buf.extend(write_variable_integer(self.subscribe_id)); - buf.extend(write_variable_integer(self.track_alias)); - buf.extend(self.publisher_priority.to_be_bytes()); - - tracing::trace!("Packetized Track Stream Header message."); - } -} - -/// Implementation of object message on QUIC Stream per Track. -/// This message is sent following TrackHeader message. -#[derive(Debug, Clone, Serialize, PartialEq)] -pub struct Object { - group_id: u64, - object_id: u64, - object_payload_length: u64, - object_status: Option, - object_payload: Vec, -} - -impl Object { - pub fn new( - group_id: u64, - object_id: u64, - object_status: Option, - object_payload: Vec, - ) -> Result { - let object_payload_length = object_payload.len() as u64; - - if object_status.is_some() && object_payload_length != 0 { - bail!("The Object Status field is only sent if the Object Payload Length is zero."); - } - - // Any object with a status code other than zero MUST have an empty payload. - if let Some(status) = object_status { - if status != ObjectStatus::Normal && object_payload_length != 0 { - // TODO: return Termination Error Code - bail!("Any object with a status code other than zero MUST have an empty payload."); - } - } - - Ok(Object { - group_id, - object_id, - object_payload_length, - object_status, - object_payload, - }) - } - - pub fn group_id(&self) -> u64 { - self.group_id - } - - pub fn object_id(&self) -> u64 { - self.object_id - } - - pub fn object_status(&self) -> Option { - self.object_status - } -} - -impl DataStreams for Object { - fn depacketize(read_cur: &mut std::io::Cursor<&[u8]>) -> Result - where - Self: Sized, - { - let group_id = read_variable_integer(read_cur).context("group id")?; - let object_id = read_variable_integer(read_cur).context("object id")?; - let object_payload_length = - read_variable_integer(read_cur).context("object payload length")?; - - // If the length of the remaining buf is larger than object_payload_length, object_status exists. - let object_status = if object_payload_length == 0 { - let object_status_u64 = read_variable_integer(read_cur)?; - let object_status = - match ObjectStatus::try_from(object_status_u64 as u8).context("object status") { - Ok(status) => status, - Err(err) => { - // Any other value SHOULD be treated as a Protocol Violation and terminate the session with a Protocol Violation - // TODO: return Termination Error Code - bail!(err); - } - }; - - Some(object_status) - } else { - None - }; - - let object_payload = if object_payload_length > 0 { - read_fixed_length_bytes(read_cur, object_payload_length as usize) - .context("object payload")? - } else { - vec![] - }; - - tracing::trace!("Depacketized Track Stream Object message."); - - Ok(Object { - group_id, - object_id, - object_payload_length, - object_status, - object_payload, - }) - } - - fn packetize(&self, buf: &mut BytesMut) { - buf.extend(write_variable_integer(self.group_id)); - buf.extend(write_variable_integer(self.object_id)); - buf.extend(write_variable_integer(self.object_payload_length)); - if self.object_status.is_some() { - buf.extend(write_variable_integer( - u8::from(self.object_status.unwrap()) as u64, - )); - } - buf.extend(&self.object_payload); - - tracing::trace!("Packetized Track Stream Object message."); - } -} - -#[cfg(test)] -mod tests { - mod success { - use crate::messages::data_streams::{ - object_status::ObjectStatus, track_stream, track_stream::DataStreams, - }; - use bytes::BytesMut; - use std::io::Cursor; - - #[test] - fn packetize_track_stream_header() { - let subscribe_id = 0; - let track_alias = 1; - let publisher_priority = 2; - - let track_stream_header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - let mut buf = BytesMut::new(); - track_stream_header.packetize(&mut buf); - - let expected_bytes_array = [ - 0, // Subscribe ID (i) - 1, // Track Alias (i) - 2, // Subscriber Priority (8) - ]; - - assert_eq!(buf.as_ref(), expected_bytes_array); - } - - #[test] - fn depacketize_track_stream_header() { - let bytes_array = [ - 0, // Subscribe ID (i) - 1, // Track Alias (i) - 2, // Subscriber Priority (8) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let mut read_cur = Cursor::new(&buf[..]); - let depacketized_track_stream_header = - track_stream::Header::depacketize(&mut read_cur).unwrap(); - - let subscribe_id = 0; - let track_alias = 1; - let publisher_priority = 2; - - let expected_track_stream_header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - assert_eq!( - depacketized_track_stream_header, - expected_track_stream_header - ); - } - - #[test] - fn packetize_track_stream_object_normal() { - let group_id = 0; - let object_id = 1; - let object_status = None; - let object_payload = vec![0, 1, 2]; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - let mut buf = BytesMut::new(); - track_stream_object.packetize(&mut buf); - - let expected_bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 3, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - - assert_eq!(buf.as_ref(), expected_bytes_array); - } - - #[test] - fn packetize_track_stream_object_normal_and_empty_payload() { - let group_id = 0; - let object_id = 1; - let object_status = Some(ObjectStatus::Normal); - let object_payload = vec![]; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - let mut buf = BytesMut::new(); - track_stream_object.packetize(&mut buf); - - let expected_bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 0, // Object Payload Length (i) - 0, // Object Status (i) - ]; - - assert_eq!(buf.as_ref(), expected_bytes_array); - } - - #[test] - fn packetize_track_stream_object_not_normal() { - let group_id = 0; - let object_id = 1; - let object_status = Some(ObjectStatus::EndOfGroup); - let object_payload = vec![]; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - let mut buf = BytesMut::new(); - track_stream_object.packetize(&mut buf); - - let expected_bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 0, // Object Payload Length (i) - 3, // Object Status (i) - ]; - - assert_eq!(buf.as_ref(), expected_bytes_array); - } - - #[test] - fn depacketize_track_stream_object_normal() { - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 3, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let mut read_cur = Cursor::new(&buf[..]); - let depacketized_track_stream_object = - track_stream::Object::depacketize(&mut read_cur).unwrap(); - - let group_id = 0; - let object_id = 1; - let object_status = None; - let object_payload = vec![0, 1, 2]; - - let expected_track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - assert_eq!( - depacketized_track_stream_object, - expected_track_stream_object - ); - } - - #[test] - fn depacketize_track_stream_object_normal_and_empty_payload() { - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 0, // Object Payload Length (i) - 0, // Object Status (i) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let mut read_cur = Cursor::new(&buf[..]); - let depacketized_track_stream_object = - track_stream::Object::depacketize(&mut read_cur).unwrap(); - - let group_id = 0; - let object_id = 1; - let object_status = Some(ObjectStatus::Normal); - let object_payload = vec![]; - - let expected_track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - assert_eq!( - depacketized_track_stream_object, - expected_track_stream_object - ); - } - - #[test] - fn depacketize_track_stream_object_not_normal() { - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 0, // Object Payload Length (i) - 1, // Object Status (i) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let mut read_cur = Cursor::new(&buf[..]); - let depacketized_track_stream_object = - track_stream::Object::depacketize(&mut read_cur).unwrap(); - - let group_id = 0; - let object_id = 1; - let object_status = Some(ObjectStatus::DoesNotExist); - let object_payload = vec![]; - - let expected_track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - assert_eq!( - depacketized_track_stream_object, - expected_track_stream_object - ); - } - } - - mod failure { - use bytes::BytesMut; - use std::io::Cursor; - - use crate::messages::data_streams::{ - object_status::ObjectStatus, track_stream, DataStreams, - }; - #[test] - fn packetize_track_stream_object_not_normal_and_not_empty_payload() { - let group_id = 0; - let object_id = 1; - let object_status = Some(ObjectStatus::EndOfTrackAndGroup); - let object_payload = vec![0, 1, 2]; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload); - - assert!(track_stream_object.is_err()); - } - - #[test] - fn depacketize_track_stream_object_wrong_object_status() { - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 0, // Object Payload Length (i) - 2, // Object Status (i) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let mut read_cur = Cursor::new(&buf[..]); - let depacketized_track_stream_object = track_stream::Object::depacketize(&mut read_cur); - - assert!(depacketized_track_stream_object.is_err()); - } - } -} diff --git a/moqt-core/src/modules/models.rs b/moqt-core/src/modules/models.rs index 5a741dbb..5d4401a4 100644 --- a/moqt-core/src/modules/models.rs +++ b/moqt-core/src/modules/models.rs @@ -1,2 +1,3 @@ +pub mod range; pub mod subscriptions; pub mod tracks; diff --git a/moqt-core/src/modules/models/range.rs b/moqt-core/src/modules/models/range.rs new file mode 100644 index 00000000..d885c70d --- /dev/null +++ b/moqt-core/src/modules/models/range.rs @@ -0,0 +1,91 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ObjectRange { + start: Option, + end: Option, +} + +impl ObjectRange { + pub fn new( + start_group: Option, + start_object: Option, + end_group: Option, + end_object: Option, + ) -> Self { + let start = match (start_group, start_object) { + (Some(group_id), Some(object_id)) => Some(ObjectStart::new(group_id, object_id)), + _ => None, + }; + + let end = end_group.map(|group_id| ObjectEnd::new(group_id, end_object)); + + // TODO: Validate that start is before end + + Self { start, end } + } + + pub fn start_group(&self) -> Option { + let start = self.start.as_ref()?; + Some(start.group_id()) + } + + pub fn start_object(&self) -> Option { + let start = self.start.as_ref()?; + Some(start.object_id()) + } + + pub fn end_group(&self) -> Option { + let end = self.end.as_ref()?; + Some(end.group_id()) + } + + pub fn end_object(&self) -> Option { + let end = self.end.as_ref()?; + end.object_id() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ObjectStart { + group_id: u64, + object_id: u64, +} + +impl ObjectStart { + pub fn new(group_id: u64, object_id: u64) -> Self { + Self { + group_id, + object_id, + } + } + + pub fn group_id(&self) -> u64 { + self.group_id + } + + pub fn object_id(&self) -> u64 { + self.object_id + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ObjectEnd { + group_id: u64, + object_id: Option, +} + +impl ObjectEnd { + pub fn new(group_id: u64, object_id: Option) -> Self { + Self { + group_id, + object_id, + } + } + + pub fn group_id(&self) -> u64 { + self.group_id + } + + pub fn object_id(&self) -> Option { + self.object_id + } +} diff --git a/moqt-core/src/modules/models/subscriptions.rs b/moqt-core/src/modules/models/subscriptions.rs index b5129d2a..c2703eef 100644 --- a/moqt-core/src/modules/models/subscriptions.rs +++ b/moqt-core/src/modules/models/subscriptions.rs @@ -1,7 +1,7 @@ pub mod nodes; - +use super::range::{ObjectRange, ObjectStart}; use crate::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, models::tracks::{ForwardingPreference, Track}, }; @@ -17,10 +17,8 @@ pub struct Subscription { priority: u8, group_order: GroupOrder, filter_type: FilterType, - start_group: Option, - start_object: Option, - end_group: Option, - end_object: Option, + requested_object_range: ObjectRange, + actual_object_start: Option, status: Status, } @@ -36,7 +34,6 @@ impl Subscription { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, forwarding_preference: Option, ) -> Self { let track = Track::new( @@ -46,15 +43,15 @@ impl Subscription { forwarding_preference, ); + let requested_object_range = ObjectRange::new(start_group, start_object, end_group, None); + Self { track, priority, group_order, filter_type, - start_group, - start_object, - end_group, - end_object, + requested_object_range, + actual_object_start: None, status: Status::Requesting, } } @@ -80,12 +77,8 @@ impl Subscription { self.filter_type } - pub fn get_absolute_start(&self) -> (Option, Option) { - (self.start_group, self.start_object) - } - - pub fn get_absolute_end(&self) -> (Option, Option) { - (self.end_group, self.end_object) + pub fn get_requested_object_range(&self) -> ObjectRange { + self.requested_object_range.clone() } pub fn set_forwarding_preference(&mut self, forwarding_preference: ForwardingPreference) { @@ -108,18 +101,38 @@ impl Subscription { self.group_order } - pub fn is_end(&self, group_id: u64, object_id: u64) -> bool { - if self.filter_type != FilterType::AbsoluteRange { - return false; - } + pub fn set_stream_id(&mut self, group_id: u64, subgroup_id: u64, stream_id: u64) { + self.track.set_stream_id(group_id, subgroup_id, stream_id); + } + + pub fn get_all_group_ids(&self) -> Vec { + let mut group_ids = self.track.get_all_group_ids(); + group_ids.sort_unstable(); + group_ids + } + + pub fn get_subgroup_ids_for_group(&self, group_id: u64) -> Vec { + let mut subgroup_ids = self.track.get_subgroup_ids_for_group(group_id); + subgroup_ids.sort_unstable(); + subgroup_ids + } + + pub fn get_stream_id_for_subgroup(&self, group_id: u64, subgroup_id: u64) -> Option { + self.track.get_stream_id_for_subgroup(group_id, subgroup_id) + } + + pub fn set_actual_object_start(&mut self, actual_object_start: ObjectStart) { + self.actual_object_start = Some(actual_object_start); + } - group_id == self.end_group.unwrap() && object_id == self.end_object.unwrap() + pub fn get_actual_object_start(&self) -> Option { + self.actual_object_start.clone() } } #[cfg(test)] pub(crate) mod test_helper_fn { - use crate::messages::control_messages::subscribe::{FilterType, GroupOrder}; + use crate::messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}; #[derive(Debug, Clone)] pub(crate) struct SubscriptionVariables { @@ -132,7 +145,6 @@ pub(crate) mod test_helper_fn { pub(crate) start_group: Option, pub(crate) start_object: Option, pub(crate) end_group: Option, - pub(crate) end_object: Option, } pub(crate) fn common_subscription_variable() -> SubscriptionVariables { @@ -145,7 +157,6 @@ pub(crate) mod test_helper_fn { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; SubscriptionVariables { track_alias, @@ -157,7 +168,6 @@ pub(crate) mod test_helper_fn { start_group, start_object, end_group, - end_object, } } } @@ -165,8 +175,9 @@ pub(crate) mod test_helper_fn { #[cfg(test)] mod success { use crate::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, models::{ + range::ObjectStart, subscriptions::{test_helper_fn, Subscription}, tracks::ForwardingPreference, }, @@ -183,8 +194,7 @@ mod success { let start_group = Some(1); let start_object = Some(1); let end_group = Some(1); - let end_object = Some(1); - let forwarding_preference = Some(ForwardingPreference::Track); + let forwarding_preference = Some(ForwardingPreference::Subgroup); let subscription = Subscription::new( track_alias, @@ -196,7 +206,6 @@ mod success { start_group, start_object, end_group, - end_object, forwarding_preference, ); @@ -208,10 +217,15 @@ mod success { assert_eq!(subscription.priority, priority); assert_eq!(subscription.group_order, group_order); assert_eq!(subscription.filter_type, filter_type); - assert_eq!(subscription.start_group, start_group); - assert_eq!(subscription.start_object, start_object); - assert_eq!(subscription.end_group, end_group); - assert_eq!(subscription.end_object, end_object); + assert_eq!( + subscription.requested_object_range.start_group(), + start_group + ); + assert_eq!( + subscription.requested_object_range.start_object(), + start_object + ); + assert_eq!(subscription.requested_object_range.end_group(), end_group); } #[test] @@ -228,7 +242,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -253,7 +266,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -277,7 +289,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -301,7 +312,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -322,7 +332,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -346,7 +355,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -367,7 +375,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -378,7 +385,7 @@ mod success { fn set_and_get_forwarding_preference() { let variable = test_helper_fn::common_subscription_variable(); - let forwarding_preference = ForwardingPreference::Track; + let forwarding_preference = ForwardingPreference::Subgroup; let mut subscription = Subscription::new( variable.track_alias, @@ -390,7 +397,6 @@ mod success { variable.start_group, variable.start_object, variable.end_group, - variable.end_object, None, ); @@ -400,4 +406,156 @@ mod success { assert_eq!(result_forwarding_preference, forwarding_preference); } + + #[test] + fn get_stream_id_for_group() { + let variable = test_helper_fn::common_subscription_variable(); + + let mut subscription = Subscription::new( + variable.track_alias, + variable.track_namespace, + variable.track_name, + variable.subscriber_priority, + variable.group_order, + variable.filter_type, + variable.start_group, + variable.start_object, + variable.end_group, + None, + ); + + let group_id = 0; + let subgroup_ids = vec![0, 1, 2]; + let stream_ids = vec![3, 4, 5]; + + subscription.set_stream_id(group_id, subgroup_ids[0], stream_ids[0]); + subscription.set_stream_id(group_id, subgroup_ids[1], stream_ids[1]); + subscription.set_stream_id(group_id, subgroup_ids[2], stream_ids[2]); + + let result_subgroup_ids = subscription.get_subgroup_ids_for_group(group_id); + + assert_eq!(result_subgroup_ids, subgroup_ids); + + let result_stream_id = vec![ + subscription + .get_stream_id_for_subgroup(group_id, result_subgroup_ids[0]) + .unwrap(), + subscription + .get_stream_id_for_subgroup(group_id, result_subgroup_ids[1]) + .unwrap(), + subscription + .get_stream_id_for_subgroup(group_id, result_subgroup_ids[2]) + .unwrap(), + ]; + + assert_eq!(result_stream_id, stream_ids); + } + + #[test] + fn get_requested_object_range() { + let variable = test_helper_fn::common_subscription_variable(); + + let subscription = Subscription::new( + variable.track_alias, + variable.track_namespace, + variable.track_name, + variable.subscriber_priority, + variable.group_order, + variable.filter_type, + variable.start_group, + variable.start_object, + variable.end_group, + None, + ); + + let result = subscription.get_requested_object_range(); + + assert_eq!(result.start_group(), variable.start_group); + assert_eq!(result.start_object(), variable.start_object); + assert_eq!(result.end_group(), variable.end_group); + } + + #[test] + fn set_actual_object_start() { + let variable = test_helper_fn::common_subscription_variable(); + + let mut subscription = Subscription::new( + variable.track_alias, + variable.track_namespace, + variable.track_name, + variable.subscriber_priority, + variable.group_order, + variable.filter_type, + variable.start_group, + variable.start_object, + variable.end_group, + None, + ); + + let start_group = 1; + let start_object = 1; + + subscription.set_actual_object_start(ObjectStart::new(start_group, start_object)); + + let result = subscription.get_actual_object_start().unwrap(); + + assert_eq!(result.group_id(), start_group); + assert_eq!(result.object_id(), start_object); + } + + #[test] + fn get_actual_object_start() { + let variable = test_helper_fn::common_subscription_variable(); + + let start_group = 1; + let start_object = 1; + + let mut subscription = Subscription::new( + variable.track_alias, + variable.track_namespace, + variable.track_name, + variable.subscriber_priority, + variable.group_order, + variable.filter_type, + variable.start_group, + variable.start_object, + variable.end_group, + None, + ); + + subscription.set_actual_object_start(ObjectStart::new(start_group, start_object)); + + let result = subscription.get_actual_object_start().unwrap(); + + assert_eq!(result.group_id(), start_group); + assert_eq!(result.object_id(), start_object); + } + + #[test] + fn get_all_group_ids() { + let variable = test_helper_fn::common_subscription_variable(); + + let mut subscription = Subscription::new( + variable.track_alias, + variable.track_namespace, + variable.track_name, + variable.subscriber_priority, + variable.group_order, + variable.filter_type, + variable.start_group, + variable.start_object, + variable.end_group, + None, + ); + + let group_ids = vec![0, 1, 2]; + + subscription.set_stream_id(group_ids[0], 0, 0); + subscription.set_stream_id(group_ids[1], 0, 0); + subscription.set_stream_id(group_ids[2], 0, 0); + + let result = subscription.get_all_group_ids(); + + assert_eq!(result, group_ids); + } } diff --git a/moqt-core/src/modules/models/subscriptions/nodes/consumers.rs b/moqt-core/src/modules/models/subscriptions/nodes/consumers.rs index 82f5ed1d..4258db41 100644 --- a/moqt-core/src/modules/models/subscriptions/nodes/consumers.rs +++ b/moqt-core/src/modules/models/subscriptions/nodes/consumers.rs @@ -1,13 +1,13 @@ -use anyhow::{bail, Result}; -use std::collections::HashMap; - use crate::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, models::{ + range::{ObjectRange, ObjectStart}, subscriptions::{nodes::registry::SubscriptionNodeRegistry, Subscription}, tracks::ForwardingPreference, }, }; +use anyhow::{bail, Result}; +use std::collections::HashMap; type SubscribeId = u64; type TrackNamespace = Vec; @@ -47,7 +47,6 @@ impl SubscriptionNodeRegistry for Consumer { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<()> { // Subscriber cannot define forwarding preference until it receives object message. let subscription = Subscription::new( @@ -60,7 +59,6 @@ impl SubscriptionNodeRegistry for Consumer { start_group, start_object, end_group, - end_object, None, ); @@ -74,36 +72,40 @@ impl SubscriptionNodeRegistry for Consumer { Ok(self.subscriptions.get(&subscribe_id).cloned()) } - fn get_subscription_by_full_track_name( + fn get_subscribe_id( &self, track_namespace: TrackNamespace, track_name: String, - ) -> Result> { + ) -> Result> { Ok(self .subscriptions - .values() - .find(|subscription| { + .iter() + .find(|(_, subscription)| { subscription.get_track_namespace_and_name() == (track_namespace.clone(), track_name.clone()) }) - .cloned()) + .map(|(subscribe_id, _)| *subscribe_id)) } - fn get_subscribe_id( + fn get_track_alias(&self, subscribe_id: SubscribeId) -> Result> { + unimplemented!("subscribe_id: {}", subscribe_id) + } + + fn get_subscribe_id_by_track_alias( &self, - track_namespace: TrackNamespace, - track_name: String, + track_alias: TrackAlias, ) -> Result> { Ok(self .subscriptions .iter() - .find(|(_, subscription)| { - subscription.get_track_namespace_and_name() - == (track_namespace.clone(), track_name.clone()) - }) + .find(|(_, subscription)| subscription.get_track_alias() == track_alias) .map(|(subscribe_id, _)| *subscribe_id)) } + fn get_all_subscribe_ids(&self) -> Result> { + Ok(self.subscriptions.keys().cloned().collect()) + } + fn has_track(&self, track_namespace: TrackNamespace, track_name: String) -> bool { self.subscriptions.values().any(|subscription| { subscription.get_track_namespace_and_name() @@ -151,18 +153,81 @@ impl SubscriptionNodeRegistry for Consumer { Ok(forwarding_preference) } - fn get_filter_type(&self, subscribe_id: SubscribeId) -> Result { - unimplemented!("subscribe_id: {}", subscribe_id) + fn get_filter_type(&self, subscribe_id: SubscribeId) -> Result> { + let filter_type = self + .subscriptions + .get(&subscribe_id) + .map(|subscription| subscription.get_filter_type()); + Ok(filter_type) } - fn get_absolute_start(&self, subscribe_id: SubscribeId) -> Result<(Option, Option)> { - unimplemented!("subscribe_id: {}", subscribe_id) + fn get_requested_object_range(&self, subscribe_id: SubscribeId) -> Result> { + let requested_object_range = self + .subscriptions + .get(&subscribe_id) + .map(|subscription| subscription.get_requested_object_range()); + Ok(requested_object_range) + } + + fn set_actual_object_start( + &mut self, + subscribe_id: SubscribeId, + actual_object_start: ObjectStart, + ) -> Result<()> { + unimplemented!( + "subscribe_id: {}, actual_object_start: {:?}", + subscribe_id, + actual_object_start + ) } - fn get_absolute_end(&self, subscribe_id: SubscribeId) -> Result<(Option, Option)> { + fn get_actual_object_start(&self, subscribe_id: SubscribeId) -> Result> { unimplemented!("subscribe_id: {}", subscribe_id) } + fn set_stream_id( + &mut self, + subscribe_id: SubscribeId, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()> { + let subscription = self.subscriptions.get_mut(&subscribe_id).unwrap(); + subscription.set_stream_id(group_id, subgroup_id, stream_id); + + Ok(()) + } + + fn get_group_ids_for_subscription(&self, subscribe_id: SubscribeId) -> Result> { + let subscription = self.subscriptions.get(&subscribe_id).unwrap(); + let group_ids = subscription.get_all_group_ids(); + + Ok(group_ids) + } + + fn get_subgroup_ids_for_group( + &self, + subscribe_id: SubscribeId, + group_id: u64, + ) -> Result> { + let subscriprion = self.subscriptions.get(&subscribe_id).unwrap(); + let subgroup_ids = subscriprion.get_subgroup_ids_for_group(group_id); + + Ok(subgroup_ids) + } + + fn get_stream_id_for_subgroup( + &self, + subscribe_id: SubscribeId, + group_id: u64, + subgroup_id: u64, + ) -> Result> { + let subscription = self.subscriptions.get(&subscribe_id).unwrap(); + let stream_id = subscription.get_stream_id_for_subgroup(group_id, subgroup_id); + + Ok(stream_id) + } + fn is_subscribe_id_unique(&self, subscribe_id: SubscribeId) -> bool { !self.subscriptions.contains_key(&subscribe_id) } @@ -285,7 +350,6 @@ pub(crate) mod test_helper_fn { pub(crate) start_group: Option, pub(crate) start_object: Option, pub(crate) end_group: Option, - pub(crate) end_object: Option, } pub(crate) fn common_subscription_variable(subscribe_id: u64) -> SubscriptionVariables { @@ -299,7 +363,6 @@ pub(crate) mod test_helper_fn { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; SubscriptionVariables { consumer, @@ -313,7 +376,6 @@ pub(crate) mod test_helper_fn { start_group, start_object, end_group, - end_object, } } } @@ -347,7 +409,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); assert!(result.is_ok()); @@ -370,7 +431,6 @@ mod success { variables_clone.start_group, variables_clone.start_object, variables_clone.end_group, - variables_clone.end_object, ); let subscription = variables_clone @@ -388,7 +448,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, None, )); @@ -396,52 +455,36 @@ mod success { } #[test] - fn get_subscription_by_full_track_name() { + fn get_subscribe_id() { let subscribe_id = 0; - let variables = test_helper_fn::common_subscription_variable(subscribe_id); - - let mut variables_clone = variables.clone(); - let _ = variables_clone.consumer.set_subscription( - variables_clone.subscribe_id, - variables_clone.track_alias, - variables_clone.track_namespace, - variables_clone.track_name, - variables_clone.subscriber_priority, - variables_clone.group_order, - variables_clone.filter_type, - variables_clone.start_group, - variables_clone.start_object, - variables_clone.end_group, - variables_clone.end_object, - ); - - let subscription = variables_clone - .consumer - .get_subscription_by_full_track_name( - variables.track_namespace.clone(), - variables.track_name.clone(), - ) - .unwrap(); + let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); - let expected_subscription = Some(Subscription::new( + let _ = variables.consumer.set_subscription( + variables.subscribe_id, variables.track_alias, - variables.track_namespace, - variables.track_name, + variables.track_namespace.clone(), + variables.track_name.clone(), variables.subscriber_priority, variables.group_order, variables.filter_type, variables.start_group, variables.start_object, variables.end_group, - variables.end_object, - None, - )); + ); - assert_eq!(subscription, expected_subscription); + let expected_subscribe_id = variables.subscribe_id; + + let result_subscribe_id = variables + .consumer + .get_subscribe_id(variables.track_namespace, variables.track_name) + .unwrap() + .unwrap(); + + assert_eq!(result_subscribe_id, expected_subscribe_id); } #[test] - fn get_subscribe_id() { + fn get_subscribe_id_by_track_alias() { let subscribe_id = 0; let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); @@ -456,14 +499,13 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let expected_subscribe_id = variables.subscribe_id; let result_subscribe_id = variables .consumer - .get_subscribe_id(variables.track_namespace, variables.track_name) + .get_subscribe_id_by_track_alias(variables.track_alias) .unwrap() .unwrap(); @@ -486,7 +528,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -512,7 +553,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -539,7 +579,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables.consumer.is_requesting(variables.subscribe_id); @@ -563,7 +602,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -589,7 +627,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let forwarding_preference = ForwardingPreference::Datagram; @@ -617,7 +654,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let forwarding_preference = ForwardingPreference::Datagram; @@ -638,7 +674,6 @@ mod success { } #[test] - #[should_panic] fn get_filter_type() { let subscribe_id = 0; let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); @@ -654,46 +689,25 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); - let _ = variables + let result_filter_type = variables .consumer .get_filter_type(variables.subscribe_id) + .unwrap() .unwrap(); - } - #[test] - #[should_panic] - fn get_absolute_start() { - let subscribe_id = 0; - let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); - - let _ = variables.consumer.set_subscription( - variables.subscribe_id, - variables.track_alias, - variables.track_namespace.clone(), - variables.track_name.clone(), - variables.subscriber_priority, - variables.group_order, - variables.filter_type, - variables.start_group, - variables.start_object, - variables.end_group, - variables.end_object, - ); - let _ = variables - .consumer - .get_absolute_start(variables.subscribe_id) - .unwrap(); + assert_eq!(result_filter_type, variables.filter_type); } - #[test] - #[should_panic] - fn get_absolute_end() { + fn get_requested_object_range() { let subscribe_id = 0; let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); + let start_group = Some(0); + let start_object = Some(0); + let end_group = Some(1); + let _ = variables.consumer.set_subscription( variables.subscribe_id, variables.track_alias, @@ -702,16 +716,20 @@ mod success { variables.subscriber_priority, variables.group_order, variables.filter_type, - variables.start_group, - variables.start_object, - variables.end_group, - variables.end_object, + start_group, + start_object, + end_group, ); - let _ = variables + let result_range = variables .consumer - .get_absolute_end(variables.subscribe_id) + .get_requested_object_range(variables.subscribe_id) + .unwrap() .unwrap(); + + assert_eq!(result_range.start_group(), start_group); + assert_eq!(result_range.start_object(), start_object); + assert_eq!(result_range.end_group(), end_group); } #[test] @@ -730,7 +748,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -762,7 +779,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables.consumer.is_subscribe_id_less_than_max(9); @@ -789,7 +805,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -829,7 +844,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables diff --git a/moqt-core/src/modules/models/subscriptions/nodes/producers.rs b/moqt-core/src/modules/models/subscriptions/nodes/producers.rs index 7d7840b2..2e988e08 100644 --- a/moqt-core/src/modules/models/subscriptions/nodes/producers.rs +++ b/moqt-core/src/modules/models/subscriptions/nodes/producers.rs @@ -1,13 +1,13 @@ -use anyhow::{bail, Result}; -use std::collections::HashMap; - use crate::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, models::{ + range::{ObjectRange, ObjectStart}, subscriptions::{nodes::registry::SubscriptionNodeRegistry, Subscription}, tracks::ForwardingPreference, }, }; +use anyhow::{bail, Result}; +use std::collections::HashMap; type SubscribeId = u64; type TrackNamespace = Vec; @@ -45,7 +45,6 @@ impl SubscriptionNodeRegistry for Producer { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<()> { // Publisher can define forwarding preference when it publishes track. let subscription = Subscription::new( @@ -58,7 +57,6 @@ impl SubscriptionNodeRegistry for Producer { start_group, start_object, end_group, - end_object, None, ); @@ -71,33 +69,36 @@ impl SubscriptionNodeRegistry for Producer { Ok(self.subscriptions.get(&subscribe_id).cloned()) } - fn get_subscription_by_full_track_name( + fn get_subscribe_id( &self, track_namespace: TrackNamespace, track_name: String, - ) -> Result> { + ) -> Result> { Ok(self .subscriptions - .values() - .find(|subscription| { + .iter() + .find(|(_, subscription)| { subscription.get_track_namespace_and_name() == (track_namespace.clone(), track_name.clone()) }) - .cloned()) + .map(|(subscribe_id, _)| *subscribe_id)) } - fn get_subscribe_id( + fn get_track_alias(&self, subscribe_id: SubscribeId) -> Result> { + Ok(self + .subscriptions + .get(&subscribe_id) + .map(|subscription| subscription.get_track_alias())) + } + + fn get_subscribe_id_by_track_alias( &self, - track_namespace: TrackNamespace, - track_name: String, + track_alias: TrackAlias, ) -> Result> { Ok(self .subscriptions .iter() - .find(|(_, subscription)| { - subscription.get_track_namespace_and_name() - == (track_namespace.clone(), track_name.clone()) - }) + .find(|(_, subscription)| subscription.get_track_alias() == track_alias) .map(|(subscribe_id, _)| *subscribe_id)) } @@ -108,6 +109,10 @@ impl SubscriptionNodeRegistry for Producer { }) } + fn get_all_subscribe_ids(&self) -> Result> { + Ok(self.subscriptions.keys().cloned().collect()) + } + fn activate_subscription(&mut self, subscribe_id: SubscribeId) -> Result { let subscription = self.subscriptions.get_mut(&subscribe_id).unwrap(); let is_activated = subscription.activate(); @@ -145,34 +150,88 @@ impl SubscriptionNodeRegistry for Producer { unimplemented!("subscribe_id: {}", subscribe_id) } - fn get_filter_type(&self, subscribe_id: SubscribeId) -> Result { + fn get_filter_type(&self, subscribe_id: SubscribeId) -> Result> { let filter_type = self .subscriptions .get(&subscribe_id) - .map(|subscription| subscription.get_filter_type()) - .unwrap(); + .map(|subscription| subscription.get_filter_type()); Ok(filter_type) } - fn get_absolute_start(&self, subscribe_id: SubscribeId) -> Result<(Option, Option)> { - let (start_group, start_object) = self + fn get_requested_object_range(&self, subscribe_id: SubscribeId) -> Result> { + let requested_object_range = self .subscriptions .get(&subscribe_id) - .map(|subscription| subscription.get_absolute_start()) - .unwrap(); + .map(|subscription| subscription.get_requested_object_range()); - Ok((start_group, start_object)) + Ok(requested_object_range) } - fn get_absolute_end(&self, subscribe_id: SubscribeId) -> Result<(Option, Option)> { - let (end_group, end_object) = self + fn set_actual_object_start( + &mut self, + subscribe_id: SubscribeId, + actual_object_start: ObjectStart, + ) -> Result<()> { + self.subscriptions + .get_mut(&subscribe_id) + .unwrap() + .set_actual_object_start(actual_object_start); + + Ok(()) + } + + fn get_actual_object_start(&self, subscribe_id: SubscribeId) -> Result> { + let actual_object_start = self .subscriptions .get(&subscribe_id) - .map(|subscription| subscription.get_absolute_end()) + .map(|subscription| subscription.get_actual_object_start()) .unwrap(); - Ok((end_group, end_object)) + Ok(actual_object_start) + } + + fn set_stream_id( + &mut self, + subscribe_id: SubscribeId, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()> { + let subscription = self.subscriptions.get_mut(&subscribe_id).unwrap(); + subscription.set_stream_id(group_id, subgroup_id, stream_id); + + Ok(()) + } + + fn get_group_ids_for_subscription(&self, subscribe_id: SubscribeId) -> Result> { + let subscription = self.subscriptions.get(&subscribe_id).unwrap(); + let group_ids = subscription.get_all_group_ids(); + + Ok(group_ids) + } + + fn get_subgroup_ids_for_group( + &self, + subscribe_id: SubscribeId, + group_id: u64, + ) -> Result> { + let subscriprion = self.subscriptions.get(&subscribe_id).unwrap(); + let subgroup_ids = subscriprion.get_subgroup_ids_for_group(group_id); + + Ok(subgroup_ids) + } + + fn get_stream_id_for_subgroup( + &self, + subscribe_id: SubscribeId, + group_id: u64, + subgroup_id: u64, + ) -> Result> { + let subscription = self.subscriptions.get(&subscribe_id).unwrap(); + let stream_id = subscription.get_stream_id_for_subgroup(group_id, subgroup_id); + + Ok(stream_id) } fn is_subscribe_id_unique(&self, subscribe_id: SubscribeId) -> bool { @@ -335,6 +394,7 @@ pub(crate) mod test_helper_fn { #[cfg(test)] mod success { use crate::models::{ + range::ObjectStart, subscriptions::{ nodes::{ producers::{test_helper_fn, Producer}, @@ -361,7 +421,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); assert!(result.is_ok()); @@ -384,7 +443,6 @@ mod success { variables_clone.start_group, variables_clone.start_object, variables_clone.end_group, - variables_clone.end_object, ); let subscription = variables_clone @@ -402,7 +460,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, None, )); @@ -410,52 +467,65 @@ mod success { } #[test] - fn get_subscription_by_full_track_name() { + fn get_subscribe_id() { let subscribe_id = 0; - let variables = test_helper_fn::common_subscription_variable(subscribe_id); + let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); - let mut variables_clone = variables.clone(); - let _ = variables_clone.producer.set_subscription( - variables_clone.subscribe_id, - variables_clone.track_alias, - variables_clone.track_namespace, - variables_clone.track_name, - variables_clone.subscriber_priority, - variables_clone.group_order, - variables_clone.filter_type, - variables_clone.start_group, - variables_clone.start_object, - variables_clone.end_group, - variables_clone.end_object, + let _ = variables.producer.set_subscription( + variables.subscribe_id, + variables.track_alias, + variables.track_namespace.clone(), + variables.track_name.clone(), + variables.subscriber_priority, + variables.group_order, + variables.filter_type, + variables.start_group, + variables.start_object, + variables.end_group, ); - let subscription = variables_clone + let expected_subscribe_id = variables.subscribe_id; + + let result_subscribe_id = variables .producer - .get_subscription_by_full_track_name( - variables.track_namespace.clone(), - variables.track_name.clone(), - ) + .get_subscribe_id(variables.track_namespace, variables.track_name) + .unwrap() .unwrap(); - let expected_subscription = Some(Subscription::new( + assert_eq!(result_subscribe_id, expected_subscribe_id); + } + + #[test] + fn get_track_alias() { + let subscribe_id = 0; + let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); + + let _ = variables.producer.set_subscription( + variables.subscribe_id, variables.track_alias, - variables.track_namespace, - variables.track_name, + variables.track_namespace.clone(), + variables.track_name.clone(), variables.subscriber_priority, variables.group_order, variables.filter_type, variables.start_group, variables.start_object, variables.end_group, - variables.end_object, - None, - )); + ); - assert_eq!(subscription, expected_subscription); + let expected_track_alias = variables.track_alias; + + let result_track_alias = variables + .producer + .get_track_alias(variables.subscribe_id) + .unwrap() + .unwrap(); + + assert_eq!(result_track_alias, expected_track_alias); } #[test] - fn get_subscribe_id() { + fn get_subscribe_id_by_track_alias() { let subscribe_id = 0; let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); @@ -470,14 +540,13 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let expected_subscribe_id = variables.subscribe_id; let result_subscribe_id = variables .producer - .get_subscribe_id(variables.track_namespace, variables.track_name) + .get_subscribe_id_by_track_alias(variables.track_alias) .unwrap() .unwrap(); @@ -500,7 +569,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -526,7 +594,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -553,7 +620,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables.producer.is_requesting(variables.subscribe_id); @@ -577,7 +643,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -603,7 +668,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let forwarding_preference = ForwardingPreference::Datagram; @@ -632,7 +696,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let forwarding_preference = ForwardingPreference::Datagram; @@ -664,19 +727,19 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); - let result = variables + let result_filter_type = variables .producer .get_filter_type(variables.subscribe_id) + .unwrap() .unwrap(); - assert_eq!(result, variables.filter_type); + assert_eq!(result_filter_type, variables.filter_type); } #[test] - fn get_absolute_start() { + fn get_requested_object_range() { let subscribe_id = 0; let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); @@ -691,24 +754,54 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); - let result = variables + let result_range = variables .producer - .get_absolute_start(variables.subscribe_id) + .get_requested_object_range(variables.subscribe_id) + .unwrap() .unwrap(); - let expected_result = (variables.start_group, variables.start_object); + assert_eq!(result_range.start_group(), variables.start_group); + assert_eq!(result_range.start_object(), variables.start_object); + assert_eq!(result_range.end_group(), variables.end_group); + assert_eq!(result_range.end_object(), variables.end_object); + } - assert_eq!(result, expected_result); + #[test] + fn set_actual_object_start() { + let subscribe_id = 0; + let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); + + let _ = variables.producer.set_subscription( + variables.subscribe_id, + variables.track_alias, + variables.track_namespace.clone(), + variables.track_name.clone(), + variables.subscriber_priority, + variables.group_order, + variables.filter_type, + variables.start_group, + variables.start_object, + variables.end_group, + ); + + let actual_object_start = ObjectStart::new(0, 0); + + let result = variables + .producer + .set_actual_object_start(variables.subscribe_id, actual_object_start); + + assert!(result.is_ok()); } #[test] - fn get_absolute_end() { + fn get_actual_object_start() { let subscribe_id = 0; let mut variables = test_helper_fn::common_subscription_variable(subscribe_id); + let actual_object_start = ObjectStart::new(0, 0); + let _ = variables.producer.set_subscription( variables.subscribe_id, variables.track_alias, @@ -720,17 +813,19 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); + let _ = variables + .producer + .set_actual_object_start(variables.subscribe_id, actual_object_start.clone()); + let result = variables .producer - .get_absolute_end(variables.subscribe_id) + .get_actual_object_start(variables.subscribe_id) + .unwrap() .unwrap(); - let expected_result = (variables.end_group, variables.end_object); - - assert_eq!(result, expected_result); + assert_eq!(result, actual_object_start); } #[test] @@ -749,7 +844,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -781,7 +875,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -812,7 +905,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables @@ -844,7 +936,6 @@ mod success { variables.start_group, variables.start_object, variables.end_group, - variables.end_object, ); let result = variables.producer.create_valid_track_alias(); diff --git a/moqt-core/src/modules/models/subscriptions/nodes/registry.rs b/moqt-core/src/modules/models/subscriptions/nodes/registry.rs index 31075157..8bdd3e7b 100644 --- a/moqt-core/src/modules/models/subscriptions/nodes/registry.rs +++ b/moqt-core/src/modules/models/subscriptions/nodes/registry.rs @@ -1,9 +1,12 @@ -use anyhow::Result; - use crate::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, + models::{ + range::{ObjectRange, ObjectStart}, + subscriptions::Subscription, + tracks::ForwardingPreference, + }, }; +use anyhow::Result; type SubscribeId = u64; type TrackNamespace = Vec; @@ -23,19 +26,20 @@ pub trait SubscriptionNodeRegistry { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<()>; fn get_subscription(&self, subscribe_id: SubscribeId) -> Result>; - fn get_subscription_by_full_track_name( - &self, - track_namespace: TrackNamespace, - track_name: String, - ) -> Result>; + // TODO: Unify getter methods of subscribe_id fn get_subscribe_id( &self, track_namespace: TrackNamespace, track_name: String, ) -> Result>; + fn get_track_alias(&self, subscribe_id: SubscribeId) -> Result>; + fn get_subscribe_id_by_track_alias( + &self, + track_alias: TrackAlias, + ) -> Result>; + fn get_all_subscribe_ids(&self) -> Result>; fn has_track(&self, track_namespace: TrackNamespace, track_name: String) -> bool; fn activate_subscription(&mut self, subscribe_id: SubscribeId) -> Result; fn is_requesting(&self, subscribe_id: SubscribeId) -> bool; @@ -49,9 +53,33 @@ pub trait SubscriptionNodeRegistry { &self, subscribe_id: SubscribeId, ) -> Result>; - fn get_filter_type(&self, subscribe_id: SubscribeId) -> Result; - fn get_absolute_start(&self, subscribe_id: SubscribeId) -> Result<(Option, Option)>; - fn get_absolute_end(&self, subscribe_id: SubscribeId) -> Result<(Option, Option)>; + fn get_filter_type(&self, subscribe_id: SubscribeId) -> Result>; + fn set_stream_id( + &mut self, + subscribe_id: SubscribeId, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()>; + fn get_group_ids_for_subscription(&self, subscribe_id: SubscribeId) -> Result>; + fn get_subgroup_ids_for_group( + &self, + subscribe_id: SubscribeId, + group_id: u64, + ) -> Result>; + fn get_stream_id_for_subgroup( + &self, + subscribe_id: SubscribeId, + group_id: u64, + subgroup_id: u64, + ) -> Result>; + fn get_requested_object_range(&self, subscribe_id: SubscribeId) -> Result>; + fn set_actual_object_start( + &mut self, + subscribe_id: SubscribeId, + actual_object_start: ObjectStart, + ) -> Result<()>; + fn get_actual_object_start(&self, subscribe_id: SubscribeId) -> Result>; fn is_subscribe_id_unique(&self, subscribe_id: SubscribeId) -> bool; fn is_subscribe_id_less_than_max(&self, subscribe_id: SubscribeId) -> bool; diff --git a/moqt-core/src/modules/models/tracks.rs b/moqt-core/src/modules/models/tracks.rs index d06e0ce8..2914f50a 100644 --- a/moqt-core/src/modules/models/tracks.rs +++ b/moqt-core/src/modules/models/tracks.rs @@ -1,7 +1,12 @@ +use std::collections::HashMap; + +type GroupId = u64; +type StreamId = u64; +type SubgroupId = u64; + #[derive(Debug, Clone, PartialEq, Eq)] pub enum ForwardingPreference { Datagram, - Track, Subgroup, } @@ -11,6 +16,7 @@ pub struct Track { track_namespace: Vec, track_name: String, forwarding_preference: Option, + group_subgroup_stream_map: HashMap>, } impl Track { @@ -25,6 +31,7 @@ impl Track { track_namespace, track_name, forwarding_preference, + group_subgroup_stream_map: HashMap::new(), } } @@ -43,6 +50,31 @@ impl Track { pub fn get_track_alias(&self) -> u64 { self.track_alias } + + pub fn set_stream_id(&mut self, group_id: u64, subgroup_id: u64, stream_id: u64) { + self.group_subgroup_stream_map + .entry(group_id) + .or_default() + .insert(subgroup_id, stream_id); + } + + pub fn get_all_group_ids(&self) -> Vec { + self.group_subgroup_stream_map.keys().cloned().collect() + } + + pub fn get_subgroup_ids_for_group(&self, group_id: u64) -> Vec { + self.group_subgroup_stream_map + .get(&group_id) + .map(|subgroup_stream_map| subgroup_stream_map.keys().cloned().collect()) + .unwrap_or_default() + } + + pub fn get_stream_id_for_subgroup(&self, group_id: u64, subgroup_id: u64) -> Option { + self.group_subgroup_stream_map + .get(&group_id) + .and_then(|subgroup_stream_map| subgroup_stream_map.get(&subgroup_id)) + .cloned() + } } #[cfg(test)] diff --git a/moqt-core/src/modules/pubsub_relation_manager_repository.rs b/moqt-core/src/modules/pubsub_relation_manager_repository.rs index f7de3c28..ca1b53b2 100644 --- a/moqt-core/src/modules/pubsub_relation_manager_repository.rs +++ b/moqt-core/src/modules/pubsub_relation_manager_repository.rs @@ -1,10 +1,10 @@ -use anyhow::Result; -use async_trait::async_trait; - +use super::models::range::{ObjectRange, ObjectStart}; use crate::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, + models::tracks::ForwardingPreference, }; +use anyhow::Result; +use async_trait::async_trait; #[async_trait] pub trait PubSubRelationManagerRepository: Send + Sync { @@ -48,38 +48,34 @@ pub trait PubSubRelationManagerRepository: Send + Sync { track_alias: u64, downstream_session_id: usize, ) -> Result; - async fn is_track_existing( + async fn is_upstream_subscribed( &self, track_namespace: Vec, track_name: String, ) -> Result; - async fn get_upstream_subscription_by_full_track_name( - &self, - track_namespace: Vec, - track_name: String, - ) -> Result>; - async fn get_upstream_subscription_by_ids( - &self, - upstream_session_id: usize, - upstream_subscribe_id: u64, - ) -> Result>; - async fn get_downstream_subscription_by_ids( - &self, - downstream_session_id: usize, - downstream_subscribe_id: u64, - ) -> Result>; async fn get_upstream_session_id(&self, track_namespace: Vec) -> Result>; async fn get_requesting_downstream_session_ids_and_subscribe_ids( &self, upstream_subscribe_id: u64, upstream_session_id: usize, ) -> Result>>; + // TODO: Unify getter methods of subscribe_id async fn get_upstream_subscribe_id( &self, track_namespace: Vec, track_name: String, upstream_session_id: usize, ) -> Result>; + async fn get_downstream_track_alias( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result>; + async fn get_upstream_subscribe_id_by_track_alias( + &self, + upstream_session_id: usize, + upstream_track_alias: u64, + ) -> Result>; #[allow(clippy::too_many_arguments)] async fn set_downstream_subscription( &self, @@ -94,7 +90,6 @@ pub trait PubSubRelationManagerRepository: Send + Sync { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<()>; #[allow(clippy::too_many_arguments)] async fn set_upstream_subscription( @@ -108,7 +103,6 @@ pub trait PubSubRelationManagerRepository: Send + Sync { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<(u64, u64)>; async fn set_pubsub_relation( &self, @@ -180,6 +174,97 @@ pub trait PubSubRelationManagerRepository: Send + Sync { upstream_session_id: usize, upstream_subscribe_id: u64, ) -> Result>; + async fn get_upstream_filter_type( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result>; + async fn get_downstream_filter_type( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result>; + async fn get_upstream_requested_object_range( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result>; + async fn get_downstream_requested_object_range( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result>; + async fn set_downstream_actual_object_start( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + actual_object_start: ObjectStart, + ) -> Result<()>; + async fn get_downstream_actual_object_start( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result>; + async fn set_upstream_stream_id( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()>; + async fn get_upstream_subscribe_ids_for_client( + &self, + upstream_session_id: usize, + ) -> Result>; + async fn get_upstream_group_ids_for_subscription( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result>; + async fn get_upstream_subgroup_ids_for_group( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + ) -> Result>; + async fn get_upstream_stream_id_for_subgroup( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + ) -> Result>; + async fn set_downstream_stream_id( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()>; + async fn get_downstream_subscribe_ids_for_client( + &self, + downstream_session_id: usize, + ) -> Result>; + async fn get_downstream_group_ids_for_subscription( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result>; + async fn get_downstream_subgroup_ids_for_group( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + ) -> Result>; + async fn get_downstream_stream_id_for_subgroup( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + ) -> Result>; async fn get_related_subscribers( &self, upstream_session_id: usize, diff --git a/moqt-core/src/modules/send_stream_dispatcher_repository.rs b/moqt-core/src/modules/send_stream_dispatcher_repository.rs deleted file mode 100644 index 1d958cc4..00000000 --- a/moqt-core/src/modules/send_stream_dispatcher_repository.rs +++ /dev/null @@ -1,19 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; - -use crate::{constants::StreamDirection, messages::moqt_payload::MOQTPayload}; - -#[async_trait] -pub trait SendStreamDispatcherRepository: Send + Sync { - async fn broadcast_message_to_send_stream_threads( - &self, - session_id: Option, - message: Box, - ) -> Result<()>; - async fn transfer_message_to_send_stream_thread( - &self, - session_id: usize, - message: Box, - stream_direction: StreamDirection, - ) -> Result<()>; -} diff --git a/moqt-core/src/modules/variable_bytes.rs b/moqt-core/src/modules/variable_bytes.rs index 45e5b3ac..b7e1baab 100644 --- a/moqt-core/src/modules/variable_bytes.rs +++ b/moqt-core/src/modules/variable_bytes.rs @@ -6,84 +6,75 @@ use crate::variable_integer::{read_variable_integer, write_variable_integer}; // See https://datatracker.ietf.org/doc/html/draft-ietf-moq-transport-01#name-notational-conventions -pub fn read_variable_bytes_from_buffer(buf: &mut BytesMut) -> Result> { - //! this function is used for x (b) format. - //! x (b) : Indicates that x consists of a variable length integer, followed by that many bytes of binary data. +pub fn read_bytes_from_buffer(buf: &mut BytesMut, length: usize) -> Result> { + //! this function is used for x (A) format. + //! x (A): Indicates that x is A bits long let mut cur = Cursor::new(&buf[..]); - let ret = read_variable_bytes(&mut cur); + let ret = read_bytes(&mut cur, length); buf.advance(cur.position() as usize); ret } -pub fn read_variable_bytes(buf: &mut std::io::Cursor<&[u8]>) -> Result> { +pub fn read_bytes(buf: &mut std::io::Cursor<&[u8]>, length: usize) -> Result> { if buf.remaining() == 0 { bail!("buffer is empty in read_variable_bytes"); } - let len = read_variable_integer(buf)? as usize; - - if buf.remaining() < len { + if buf.remaining() < length { bail!( "buffer does not have enough length. actual: {}, expected: {}", buf.remaining() + 1, - len + 1 + length + 1 ) } - let value = buf.get_ref()[buf.position() as usize..buf.position() as usize + len] + let value = buf.get_ref()[buf.position() as usize..buf.position() as usize + length] .as_ref() .to_vec(); - buf.advance(value.len()); + buf.advance(length); Ok(value) } -pub fn read_fixed_length_bytes_from_buffer(buf: &mut BytesMut, length: usize) -> Result> { - //! this function is used for x (A) format. - //! x (A): Indicates that x is A bits long +pub fn read_variable_bytes_from_buffer(buf: &mut BytesMut) -> Result> { + //! this function is used for x (b) format. + //! x (b) : Indicates that x consists of a variable length integer, followed by that many bytes of binary data. let mut cur = Cursor::new(&buf[..]); - let ret = read_fixed_length_bytes(&mut cur, length); + let ret = read_variable_bytes(&mut cur); buf.advance(cur.position() as usize); ret } -pub fn read_fixed_length_bytes(buf: &mut std::io::Cursor<&[u8]>, length: usize) -> Result> { + +pub fn read_variable_bytes(buf: &mut std::io::Cursor<&[u8]>) -> Result> { if buf.remaining() == 0 { bail!("buffer is empty in read_variable_bytes"); } - if buf.remaining() < length { + let len = read_variable_integer(buf)? as usize; + + if buf.remaining() < len { bail!( "buffer does not have enough length. actual: {}, expected: {}", buf.remaining() + 1, - length + 1 + len + 1 ) } - let value = buf.get_ref()[buf.position() as usize..buf.position() as usize + length] + let value = buf.get_ref()[buf.position() as usize..buf.position() as usize + len] .as_ref() .to_vec(); - buf.advance(length); + buf.advance(value.len()); Ok(value) } -pub fn read_variable_bytes_to_end_from_buffer(buf: &mut BytesMut) -> Result> { - let mut cur = Cursor::new(&buf[..]); - - let ret = read_variable_bytes_to_end(&mut cur); - - buf.advance(cur.position() as usize); - - ret -} - -pub fn read_variable_bytes_to_end(buf: &mut std::io::Cursor<&[u8]>) -> Result> { +pub fn read_all_variable_bytes(buf: &mut std::io::Cursor<&[u8]>) -> Result> { if buf.remaining() == 0 { bail!("buffer is empty in read_variable_bytes"); } @@ -110,7 +101,7 @@ pub fn write_variable_bytes(value: &Vec) -> BytesMut { buf } -pub fn write_fixed_length_bytes(value: &Vec) -> BytesMut { +pub fn write_bytes(value: &Vec) -> BytesMut { //! this function is used for x (A) format. //! x (A): Indicates that x is A bits long let mut buf = BytesMut::with_capacity(0); @@ -119,7 +110,7 @@ pub fn write_fixed_length_bytes(value: &Vec) -> BytesMut { buf } -pub fn convert_bytes_to_integer(value: Vec) -> Result { +pub fn bytes_to_integer(value: Vec) -> Result { let mut ret = 0; for (i, &byte) in value.iter().enumerate() { ret |= (byte as u64) << (i * 8); diff --git a/moqt-core/src/modules/variable_integer.rs b/moqt-core/src/modules/variable_integer.rs index 526f6816..2a6989d4 100644 --- a/moqt-core/src/modules/variable_integer.rs +++ b/moqt-core/src/modules/variable_integer.rs @@ -21,7 +21,7 @@ pub fn read_variable_integer(buf: &mut std::io::Cursor<&[u8]>) -> Result { let first_byte = buf.get_u8(); let mut value: u64 = (first_byte % 64).into(); - let rest_len = get_length_from_variable_integer_first_byte(first_byte) - 1; + let rest_len = get_2msb_length_from_first_byte(first_byte) - 1; if buf.remaining() < rest_len.into() { bail!( @@ -39,7 +39,7 @@ pub fn read_variable_integer(buf: &mut std::io::Cursor<&[u8]>) -> Result { Ok(value) } -pub fn get_length_from_variable_integer_first_byte(first_byte: u8) -> u8 { +pub fn get_2msb_length_from_first_byte(first_byte: u8) -> u8 { let msb2 = first_byte / 64; // 2MSB Length @@ -50,6 +50,18 @@ pub fn get_length_from_variable_integer_first_byte(first_byte: u8) -> u8 { 2usize.pow(msb2 as u32) as u8 } +pub fn get_2msb_value(value: u64) -> u64 { + let first_byte = (value & 0xFF) as u8; // 0xFF: Bit mask to get the first byte + let first_two_bits = first_byte / 64; + match first_two_bits { + 0b00 => value & 0x3F, // 8ビットの先頭2bitを無視 + 0b01 => value & 0x3FFF, // 16ビットの先頭2bitを無視 + 0b10 => value & 0x3FFFFFFF, // 32ビットの先頭2bitを無視 + 0b11 => value & 0x3FFFFFFFFFFFFFFF, // 64ビットの先頭2bitを無視 + _ => unreachable!(), + } +} + pub fn write_variable_integer(value: u64) -> BytesMut { let mut buf = BytesMut::with_capacity(0); diff --git a/moqt-server/Cargo.toml b/moqt-server/Cargo.toml index 9debf3ed..ac656e40 100644 --- a/moqt-server/Cargo.toml +++ b/moqt-server/Cargo.toml @@ -11,7 +11,7 @@ moqt-core = {path = "../moqt-core"} tokio = { version = "1.32.0", features = ["rt-multi-thread"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -wtransport = "0.1.12" +wtransport = { version = "0.5.0", features = ["quinn"] } bytes = "1" async-trait = "0.1.74" ttl_cache = "0.5.1" \ No newline at end of file diff --git a/moqt-server/src/lib.rs b/moqt-server/src/lib.rs index 02493424..ce9d6615 100644 --- a/moqt-server/src/lib.rs +++ b/moqt-server/src/lib.rs @@ -2,21 +2,23 @@ use anyhow::{bail, Context, Result}; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::sync::{mpsc, mpsc::Sender, Mutex}; use tracing::{self, Instrument}; +use wtransport::quinn::{TransportConfig, VarInt}; use wtransport::{Endpoint, Identity, ServerConfig}; mod modules; pub use modules::config::MOQTConfig; use modules::{ buffer_manager::{buffer_manager, BufferCommand}, + control_message_dispatcher::{control_message_dispatcher, ControlMessageDispatchCommand}, logging::init_logging, object_cache_storage::{ cache::SubgroupStreamId, commands::ObjectCacheStorageCommand, storage::object_cache_storage, }, pubsub_relation_manager::{commands::PubSubRelationCommand, manager::pubsub_relation_manager}, - send_stream_dispatcher::{send_stream_dispatcher, SendStreamDispatchCommand}, server_processes::{ senders::{SenderToOtherConnectionThread, SendersToManagementThread}, session_handler::SessionHandler, }, + signal_dispatcher, }; pub use moqt_core::constants; use moqt_core::{ @@ -24,6 +26,8 @@ use moqt_core::{ data_stream_type::DataStreamType, }; +use crate::signal_dispatcher::{signal_dispatcher, SignalDispatchCommand}; + type SubscribeId = u64; pub(crate) type SenderToOpenSubscription = Sender<(SubscribeId, DataStreamType, Option)>; @@ -55,10 +59,15 @@ impl MOQTServer { if self.underlay != UnderlayType::WebTransport { bail!("Underlay must be WebTransport, not {:?}", self.underlay); } - let config = ServerConfig::builder() + let mut transport_config = TransportConfig::default(); + transport_config.max_concurrent_uni_streams(100000u32.into()); // 単方向ストリーム数を100000に設定 + transport_config.time_threshold(1.5); + transport_config.packet_threshold(5); + transport_config.stream_receive_window(VarInt::from_u32(10 * 1024 * 1024)); // initial_max_stream_data_uniと同義。デフォルトは65,536 バイト (64KB)なので1MBにする + let mut config = ServerConfig::builder() .with_bind_default(self.port) - .with_identity( - &Identity::load_pemfiles(&self.cert_path, &self.key_path) + .with_custom_transport( + Identity::load_pemfiles(&self.cert_path, &self.key_path) .await .with_context(|| { format!( @@ -66,9 +75,11 @@ impl MOQTServer { self.cert_path, self.key_path ) })?, + transport_config, ) .keep_alive_interval(Some(Duration::from_secs(self.keep_alive_interval_sec))) .build(); + let _ = config.quic_endpoint_config_mut().max_udp_payload_size(1000); let server = Endpoint::server(config)?; tracing::info!("Server ready!"); @@ -78,8 +89,15 @@ impl MOQTServer { let (pubsub_relation_tx, mut pubsub_relation_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut pubsub_relation_rx).await }); - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let (signal_dispatch_tx, mut signal_dispatch_rx) = + mpsc::channel::(1024); + tokio::spawn(async move { signal_dispatcher(&mut signal_dispatch_rx).await }); + let (object_cache_tx, mut object_cache_rx) = mpsc::channel::(1024); tokio::spawn(async move { object_cache_storage(&mut object_cache_rx).await }); @@ -93,7 +111,8 @@ impl MOQTServer { let senders_to_management_thread = SendersToManagementThread::new( buffer_tx.clone(), pubsub_relation_tx.clone(), - send_stream_tx.clone(), + control_message_dispatch_tx.clone(), + signal_dispatch_tx.clone(), object_cache_tx.clone(), ); diff --git a/moqt-server/src/modules.rs b/moqt-server/src/modules.rs index 630eaba8..ea62c45a 100644 --- a/moqt-server/src/modules.rs +++ b/moqt-server/src/modules.rs @@ -1,9 +1,10 @@ pub mod buffer_manager; pub(crate) mod config; +pub(crate) mod control_message_dispatcher; pub(crate) mod logging; pub(crate) mod message_handlers; pub(crate) mod moqt_client; pub(crate) mod object_cache_storage; pub(crate) mod pubsub_relation_manager; -pub(crate) mod send_stream_dispatcher; pub(crate) mod server_processes; +pub(crate) mod signal_dispatcher; diff --git a/moqt-server/src/modules/control_message_dispatcher.rs b/moqt-server/src/modules/control_message_dispatcher.rs new file mode 100644 index 00000000..af5e9f76 --- /dev/null +++ b/moqt-server/src/modules/control_message_dispatcher.rs @@ -0,0 +1,92 @@ +use anyhow::Result; +use moqt_core::messages::moqt_payload::MOQTPayload; +use std::{collections::HashMap, sync::Arc}; +use tokio::sync::{mpsc, oneshot}; +type SenderToControlMessageSenderThread = mpsc::Sender>>; + +#[derive(Debug)] +pub(crate) enum ControlMessageDispatchCommand { + Set { + session_id: usize, + sender: SenderToControlMessageSenderThread, + }, + Get { + session_id: usize, + resp: oneshot::Sender>, + }, + Delete { + session_id: usize, + }, +} + +pub(crate) async fn control_message_dispatcher( + rx: &mut mpsc::Receiver, +) { + tracing::trace!("control_message_dispatcher start"); + // { + // "${session_id}" : tx, + // } + // } + let mut dispatcher = HashMap::::new(); + + while let Some(cmd) = rx.recv().await { + tracing::debug!("command received: {:#?}", cmd); + match cmd { + ControlMessageDispatchCommand::Set { session_id, sender } => { + dispatcher.insert(session_id, sender); + tracing::debug!("set: {:?}", session_id); + } + ControlMessageDispatchCommand::Get { session_id, resp } => { + let sender = dispatcher.get(&session_id).cloned(); + tracing::debug!("get: {:?}", sender); + let _ = resp.send(sender); + } + ControlMessageDispatchCommand::Delete { session_id } => { + dispatcher.remove(&session_id); + tracing::debug!("delete: {:?}", session_id); + } + } + } + + tracing::trace!("control_message_dispatcher end"); +} + +#[derive(Clone)] +pub(crate) struct ControlMessageDispatcher { + tx: mpsc::Sender, +} + +impl ControlMessageDispatcher { + pub fn new(tx: mpsc::Sender) -> Self { + Self { tx } + } + + // Used for testing in unsubscribe_handler + #[allow(dead_code)] + pub fn get_tx(&self) -> mpsc::Sender { + self.tx.clone() + } +} + +impl ControlMessageDispatcher { + pub(crate) async fn transfer_message_to_control_message_sender_thread( + &self, + session_id: usize, + message: Box, + ) -> Result<()> { + let (resp_tx, resp_rx) = oneshot::channel::>(); + + let cmd = ControlMessageDispatchCommand::Get { + session_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let sender = resp_rx + .await? + .ok_or_else(|| anyhow::anyhow!("sender not found"))?; + let message_arc = Arc::new(message); + let _ = sender.send(message_arc).await; + Ok(()) + } +} diff --git a/moqt-server/src/modules/message_handlers.rs b/moqt-server/src/modules/message_handlers.rs index 69d12dfa..93896d97 100644 --- a/moqt-server/src/modules/message_handlers.rs +++ b/moqt-server/src/modules/message_handlers.rs @@ -1,4 +1,4 @@ pub(crate) mod control_message; pub(crate) mod datagram_object; -pub(crate) mod stream_header; -pub(crate) mod stream_object; +pub(crate) mod subgroup_stream_header; +pub(crate) mod subgroup_stream_object; diff --git a/moqt-server/src/modules/message_handlers/control_message.rs b/moqt-server/src/modules/message_handlers/control_message.rs index 43fd2449..48b86d28 100644 --- a/moqt-server/src/modules/message_handlers/control_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message.rs @@ -2,6 +2,7 @@ pub(crate) mod handlers; pub(crate) mod server_processes; use crate::constants::TerminationErrorCode; +use crate::modules::control_message_dispatcher::ControlMessageDispatcher; use crate::modules::moqt_client::MOQTClient; use crate::modules::{ message_handlers::control_message::{ @@ -11,9 +12,9 @@ use crate::modules::{ announce_message::process_announce_message, announce_ok_message::process_announce_ok_message, client_setup_message::process_client_setup_message, + subscribe_announces_message::process_subscribe_announces_message, subscribe_error_message::process_subscribe_error_message, subscribe_message::process_subscribe_message, - subscribe_namespace_message::process_subscribe_namespace_message, subscribe_ok_message::process_subscribe_ok_message, }, }, @@ -29,7 +30,6 @@ use moqt_core::{ messages::{control_messages::unannounce::UnAnnounce, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, variable_integer::{read_variable_integer, write_variable_integer}, - SendStreamDispatcherRepository, }; use server_processes::unsubscribe_message::process_unsubscribe_message; use std::{collections::HashMap, io::Cursor, sync::Arc}; @@ -66,7 +66,7 @@ pub async fn control_message_handler( client: &mut MOQTClient, start_forwarder_txes: Arc>>, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, object_cache_storage: &mut ObjectCacheStorageWrapper, ) -> MessageProcessResult { tracing::trace!("control_message_handler! {}", read_buf.len()); @@ -141,7 +141,7 @@ pub async fn control_message_handler( client, &mut write_buf, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, object_cache_storage, start_forwarder_txes, ) @@ -166,7 +166,7 @@ pub async fn control_message_handler( match process_subscribe_ok_message( &mut payload_buf, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, client, ) .await @@ -186,7 +186,7 @@ pub async fn control_message_handler( match process_subscribe_error_message( &mut payload_buf, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, client, ) .await @@ -206,7 +206,7 @@ pub async fn control_message_handler( match process_unsubscribe_message( &mut payload_buf, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, client, ) .await @@ -228,7 +228,7 @@ pub async fn control_message_handler( client, &mut write_buf, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, ) .await { @@ -278,18 +278,18 @@ pub async fn control_message_handler( } } } - ControlMessageType::SubscribeNamespace => { - match process_subscribe_namespace_message( + ControlMessageType::SubscribeAnnounces => { + match process_subscribe_announces_message( &mut payload_buf, client, &mut write_buf, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, ) .await { Ok(result) => match result { - Some(_) => ControlMessageType::SubscribeNamespaceError, + Some(_) => ControlMessageType::SubscribeAnnouncesError, None => { return MessageProcessResult::SuccessWithoutResponse; } @@ -350,6 +350,9 @@ pub async fn control_message_handler( #[cfg(test)] pub(crate) mod test_helper_fn { use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, message_handlers::control_message::{control_message_handler, MessageProcessResult}, moqt_client::{MOQTClient, MOQTClientStatus}, object_cache_storage::{ @@ -360,9 +363,6 @@ pub(crate) mod test_helper_fn { commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use crate::SenderToOpenSubscription; @@ -394,12 +394,15 @@ pub(crate) mod test_helper_fn { let mut pubsub_relation_manager: PubSubRelationManagerWrapper = PubSubRelationManagerWrapper::new(track_namespace_tx); - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -418,7 +421,7 @@ pub(crate) mod test_helper_fn { &mut client, start_forwarder_txes, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, ) .await @@ -439,11 +442,11 @@ mod success { let bytes_array = [ 1, // Number of Supported Versions (i) 192, // Supported Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 2, // SETUP Parameters (..): Role(Subscriber) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff00000a) in 62bit + 1, // Number of Parameters (i) + 0, // SETUP Parameters (..): Type(Role) + 1, // SETUP Parameters (..): Length + 2, // SETUP Parameters (..): Role(Subscriber) ]; let client_status = MOQTClientStatus::Connected; @@ -466,11 +469,11 @@ mod success { let bytes_array = [ 1, // Number of Supported Versions (i) 192, // Supported Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 2, // SETUP Parameters (..): Role(Subscriber) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff00000a) in 62bit + 1, // Number of Parameters (i) + 0, // SETUP Parameters (..): Type(Role) + 1, // SETUP Parameters (..): Length + 2, // SETUP Parameters (..): Role(Subscriber) ]; let client_status = MOQTClientStatus::Connected; @@ -502,11 +505,11 @@ mod failure { let bytes_array = [ 1, // Number of Supported Versions (i) 192, // Supported Version (i): Length(11 of 2MSB) - 0, 0, 0, 255, 0, 0, 6, // Supported Version(i): Value(0xff000006) in 62bit - 1, // Number of Parameters (i) - 0, // SETUP Parameters (..): Type(Role) - 1, // SETUP Parameters (..): Length - 2, // SETUP Parameters (..): Role(Subscriber) + 0, 0, 0, 255, 0, 0, 10, // Supported Version(i): Value(0xff000a) in 62bit + 1, // Number of Parameters (i) + 0, // SETUP Parameters (..): Type(Role) + 1, // SETUP Parameters (..): Length + 2, // SETUP Parameters (..): Role(Subscriber) ]; let wrong_client_status = MOQTClientStatus::SetUp; // Correct Status is Connected diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers.rs b/moqt-server/src/modules/message_handlers/control_message/handlers.rs index be22dfa4..0611d2c6 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers.rs @@ -1,9 +1,9 @@ pub(crate) mod announce_handler; pub(crate) mod announce_ok_handler; pub(crate) mod server_setup_handler; +pub(crate) mod subscribe_announces_handler; pub(crate) mod subscribe_error_handler; pub(crate) mod subscribe_handler; -pub(crate) mod subscribe_namespace_handler; pub(crate) mod subscribe_ok_handler; pub(crate) mod unannounce_handler; pub(crate) mod unsubscribe_handler; diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/announce_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/announce_handler.rs index 31d0f590..7a5de102 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/announce_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/announce_handler.rs @@ -1,17 +1,17 @@ -use crate::modules::moqt_client::MOQTClient; +use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, moqt_client::MOQTClient, +}; use anyhow::Result; use moqt_core::{ - constants::StreamDirection, messages::control_messages::{ announce::Announce, announce_error::AnnounceError, announce_ok::AnnounceOk, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; async fn forward_announce_to_subscribers( pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, track_namespace: Vec, ) -> Result<()> { let downstream_session_ids = match pubsub_relation_manager_repository @@ -33,11 +33,10 @@ async fn forward_announce_to_subscribers( Ok(true) => {} Ok(false) => { let announce_message = Box::new(Announce::new(track_namespace.clone(), vec![])); - let _ = send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + let _ = control_message_dispatcher + .transfer_message_to_control_message_sender_thread( downstream_session_id, announce_message, - StreamDirection::Bi, ) .await; } @@ -54,7 +53,7 @@ pub(crate) async fn announce_handler( announce_message: Announce, client: &MOQTClient, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, ) -> Result> { tracing::trace!("announce_handler start."); tracing::debug!("announce_message: {:#?}", announce_message); @@ -71,18 +70,14 @@ pub(crate) async fn announce_handler( // TODO: Unify the method to send a message to the opposite client itself let announce_ok_message = Box::new(AnnounceOk::new(track_namespace.clone())); - let _ = send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( - client.id(), - announce_ok_message, - StreamDirection::Bi, - ) + let _ = control_message_dispatcher + .transfer_message_to_control_message_sender_thread(client.id(), announce_ok_message) .await; - // If subscribers already sent SUBSCRIBE_NAMESPACE, send ANNOUNCE message to them + // If subscribers already sent SUBSCRIBE_ANNOUNCES, send ANNOUNCE message to them match forward_announce_to_subscribers( pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, track_namespace.clone(), ) .await @@ -108,17 +103,16 @@ pub(crate) async fn announce_handler( #[cfg(test)] mod success { use super::announce_handler; + use crate::modules::control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }; use crate::modules::moqt_client::MOQTClient; use crate::modules::pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }; - use crate::modules::send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }; use crate::modules::server_processes::senders; use bytes::BytesMut; - use moqt_core::constants::StreamDirection; use moqt_core::messages::control_messages::{ announce::Announce, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, @@ -173,25 +167,26 @@ mod success { ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx.clone(), }) .await; - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -201,7 +196,7 @@ mod success { announce_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await .unwrap(); @@ -257,18 +252,20 @@ mod success { .set_downstream_announced_namespace(track_namespace.clone(), downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -278,7 +275,7 @@ mod success { announce_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await .unwrap(); @@ -290,17 +287,16 @@ mod success { #[cfg(test)] mod failure { use super::announce_handler; + use crate::modules::control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }; use crate::modules::moqt_client::MOQTClient; use crate::modules::pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }; - use crate::modules::send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }; use crate::modules::server_processes::senders; use bytes::BytesMut; - use moqt_core::constants::StreamDirection; use moqt_core::messages::control_messages::{ announce::Announce, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, @@ -350,18 +346,20 @@ mod failure { ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -371,7 +369,7 @@ mod failure { announce_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await; diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/server_setup_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/server_setup_handler.rs index 38c65d83..e1b23b54 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/server_setup_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/server_setup_handler.rs @@ -2,11 +2,8 @@ use anyhow::{bail, Result}; use moqt_core::{ constants::UnderlayType, messages::control_messages::{ - client_setup::ClientSetup, - server_setup::ServerSetup, - setup_parameters::MaxSubscribeID, + client_setup::ClientSetup, server_setup::ServerSetup, setup_parameters::MaxSubscribeID, setup_parameters::SetupParameter, - setup_parameters::{Role, RoleCase}, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, }; @@ -23,16 +20,12 @@ fn is_requested_version_supported(supported_versions: Vec) -> bool { } fn handle_setup_parameter( - client: &mut MOQTClient, setup_parameters: Vec, underlay_type: UnderlayType, ) -> Result { let mut max_subscribe_id: u64 = 0; for setup_parameter in &setup_parameters { match setup_parameter { - SetupParameter::Role(param) => { - client.set_role(param.value)?; - } SetupParameter::Path(_) => { if underlay_type == UnderlayType::WebTransport { bail!("PATH parameter is not allowed on WebTransport."); @@ -55,51 +48,20 @@ async fn setup_subscription_node( upstream_max_subscribe_id: u64, downstream_max_subscribe_id: u64, ) -> Result<()> { - match client.role() { - Some(RoleCase::Publisher) => { - // Generate consumer that manages namespaces and subscriptions with producers. - pubsub_relation_manager_repository - .setup_publisher(upstream_max_subscribe_id, client.id()) - .await?; - } - Some(RoleCase::Subscriber) => { - // Generate producer that manages namespaces and subscriptions with subscribers. - pubsub_relation_manager_repository - .setup_subscriber(downstream_max_subscribe_id, client.id()) - .await?; - } - Some(RoleCase::PubSub) => { - // Generate producer and consumer that manages namespaces and subscriptions with publishers and subscribers. - pubsub_relation_manager_repository - .setup_publisher(upstream_max_subscribe_id, client.id()) - .await?; - pubsub_relation_manager_repository - .setup_subscriber(downstream_max_subscribe_id, client.id()) - .await?; - } - None => { - bail!("Role parameter is required in SETUP parameter from client."); - } - } + // Generate producer and consumer that manages namespaces and subscriptions with publishers and subscribers. + pubsub_relation_manager_repository + .setup_publisher(upstream_max_subscribe_id, client.id()) + .await?; + pubsub_relation_manager_repository + .setup_subscriber(downstream_max_subscribe_id, client.id()) + .await?; + Ok(()) } -fn create_setup_parameters( - client: &MOQTClient, - downstream_max_subscribe_id: u64, -) -> Vec { +fn create_setup_parameters(downstream_max_subscribe_id: u64) -> Vec { let mut setup_parameters = vec![]; - // Create a setup parameter with role set to 3 and assign it. - // Normally, the server should determine the role here, but for now, let's set it to 3. - let role_parameter = SetupParameter::Role(Role::new(RoleCase::PubSub)); - setup_parameters.push(role_parameter); - - if client.role() == Some(RoleCase::Subscriber) || client.role() == Some(RoleCase::PubSub) { - let max_subscribe_id_parameter = - SetupParameter::MaxSubscribeID(MaxSubscribeID::new(downstream_max_subscribe_id)); - setup_parameters.push(max_subscribe_id_parameter); - } let max_subscribe_id_parameter = SetupParameter::MaxSubscribeID(MaxSubscribeID::new(downstream_max_subscribe_id)); setup_parameters.push(max_subscribe_id_parameter); @@ -114,21 +76,14 @@ pub(crate) async fn setup_handler( pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, ) -> Result { tracing::trace!("setup_handler start."); - tracing::debug!("client_setup_message: {:#?}", client_setup_message); - - tracing::debug!( - "supported_versions: {:#x?}", - client_setup_message.supported_versions - ); - if !is_requested_version_supported(client_setup_message.supported_versions) { bail!("Supported version is not included"); } // If max_subscribe_id is not included in the CLIENT_SETUP message, upstream_max_subscribe_id is set to 0. let upstream_max_subscribe_id: u64 = - handle_setup_parameter(client, client_setup_message.setup_parameters, underlay_type)?; + handle_setup_parameter(client_setup_message.setup_parameters, underlay_type)?; // FIXME: downstream_max_subscribe_id for subscriber is fixed at 100 for now. let downstream_max_subscribe_id: u64 = 100; @@ -141,7 +96,7 @@ pub(crate) async fn setup_handler( ) .await?; - let setup_parameters = create_setup_parameters(client, downstream_max_subscribe_id); + let setup_parameters = create_setup_parameters(downstream_max_subscribe_id); let server_setup_message = ServerSetup::new(constants::MOQ_TRANSPORT_VERSION, setup_parameters); // State: Connected -> Setup client.update_status(MOQTClientStatus::SetUp); @@ -165,47 +120,16 @@ mod success { }; use moqt_core::messages::control_messages::{ client_setup::ClientSetup, - setup_parameters::{Path, Role, RoleCase, SetupParameter}, + setup_parameters::{Path, SetupParameter}, }; use std::vec; use tokio::sync::mpsc; #[tokio::test] - async fn only_role() { - let senders_mock = senders::test_helper_fn::create_senders_mock(); - let mut client = MOQTClient::new(33, senders_mock); - let setup_parameters = vec![SetupParameter::Role(Role::new(RoleCase::Publisher))]; - let client_setup_message = - ClientSetup::new(vec![constants::MOQ_TRANSPORT_VERSION], setup_parameters); - let underlay_type = crate::constants::UnderlayType::WebTransport; - - // Generate PubSubRelationManagerWrapper - let (track_namespace_tx, mut track_namespace_rx) = - mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_namespace_rx).await }); - let mut pubsub_relation_manager: PubSubRelationManagerWrapper = - PubSubRelationManagerWrapper::new(track_namespace_tx); - - let server_setup_message = setup_handler( - client_setup_message, - underlay_type, - &mut client, - &mut pubsub_relation_manager, - ) - .await; - - assert!(server_setup_message.is_ok()); - let _server_setup_message = server_setup_message.unwrap(); // TODO: Not implemented yet - } - - #[tokio::test] - async fn role_and_path_on_quic() { + async fn path_on_quic() { let senders_mock = senders::test_helper_fn::create_senders_mock(); let mut client = MOQTClient::new(33, senders_mock); - let setup_parameters = vec![ - SetupParameter::Role(Role::new(RoleCase::Publisher)), - SetupParameter::Path(Path::new(String::from("test"))), - ]; + let setup_parameters = vec![SetupParameter::Path(Path::new(String::from("test")))]; let client_setup_message = ClientSetup::new(vec![constants::MOQ_TRANSPORT_VERSION], setup_parameters); let underlay_type = crate::constants::UnderlayType::QUIC; @@ -242,38 +166,11 @@ mod failure { use crate::modules::server_processes::senders; use moqt_core::messages::control_messages::{ client_setup::ClientSetup, - setup_parameters::{Path, Role, RoleCase, SetupParameter}, + setup_parameters::{Path, SetupParameter}, }; use std::vec; use tokio::sync::mpsc; - #[tokio::test] - async fn no_role_parameter() { - let senders_mock = senders::test_helper_fn::create_senders_mock(); - let mut client = MOQTClient::new(33, senders_mock); - let setup_parameters = vec![]; - let client_setup_message = - ClientSetup::new(vec![constants::MOQ_TRANSPORT_VERSION], setup_parameters); - let underlay_type = crate::constants::UnderlayType::WebTransport; - - // Generate PubSubRelationManagerWrapper - let (track_namespace_tx, mut track_namespace_rx) = - mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_namespace_rx).await }); - let mut pubsub_relation_manager: PubSubRelationManagerWrapper = - PubSubRelationManagerWrapper::new(track_namespace_tx); - - let server_setup_message = setup_handler( - client_setup_message, - underlay_type, - &mut client, - &mut pubsub_relation_manager, - ) - .await; - - assert!(server_setup_message.is_err()); - } - #[tokio::test] async fn include_path_on_wt() { let senders_mock = senders::test_helper_fn::create_senders_mock(); @@ -301,38 +198,11 @@ mod failure { assert!(server_setup_message.is_err()); } - #[tokio::test] - async fn include_only_path_on_quic() { - let senders_mock = senders::test_helper_fn::create_senders_mock(); - let mut client = MOQTClient::new(33, senders_mock); - let setup_parameters = vec![SetupParameter::Path(Path::new(String::from("test")))]; - let client_setup_message = - ClientSetup::new(vec![constants::MOQ_TRANSPORT_VERSION], setup_parameters); - let underlay_type = crate::constants::UnderlayType::QUIC; - - // Generate PubSubRelationManagerWrapper - let (track_namespace_tx, mut track_namespace_rx) = - mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_namespace_rx).await }); - let mut pubsub_relation_manager: PubSubRelationManagerWrapper = - PubSubRelationManagerWrapper::new(track_namespace_tx); - - let server_setup_message = setup_handler( - client_setup_message, - underlay_type, - &mut client, - &mut pubsub_relation_manager, - ) - .await; - - assert!(server_setup_message.is_err()); - } - #[tokio::test] async fn include_unsupported_version() { let senders_mock = senders::test_helper_fn::create_senders_mock(); let mut client = MOQTClient::new(33, senders_mock); - let setup_parameters = vec![SetupParameter::Role(Role::new(RoleCase::Subscriber))]; + let setup_parameters = vec![SetupParameter::Path(Path::new(String::from("test")))]; let unsupported_version = 8888; let client_setup_message = ClientSetup::new(vec![unsupported_version], setup_parameters); @@ -361,10 +231,7 @@ mod failure { async fn include_unknown_parameter() { let senders_mock = senders::test_helper_fn::create_senders_mock(); let mut client = MOQTClient::new(33, senders_mock); - let setup_parameters = vec![ - SetupParameter::Role(Role::new(RoleCase::Publisher)), - SetupParameter::Unknown(0), - ]; + let setup_parameters = vec![SetupParameter::Unknown(0)]; let client_setup_message = ClientSetup::new(vec![constants::MOQ_TRANSPORT_VERSION], setup_parameters); let underlay_type = crate::constants::UnderlayType::WebTransport; diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_namespace_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_announces_handler.rs similarity index 66% rename from moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_namespace_handler.rs rename to moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_announces_handler.rs index d4c9b6e7..176a774d 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_namespace_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_announces_handler.rs @@ -1,36 +1,34 @@ +use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, moqt_client::MOQTClient, +}; use anyhow::Result; - use moqt_core::{ - constants::StreamDirection, messages::{ control_messages::{ - announce::Announce, subscribe_namespace::SubscribeNamespace, - subscribe_namespace_error::SubscribeNamespaceError, - subscribe_namespace_ok::SubscribeNamespaceOk, + announce::Announce, subscribe_announces::SubscribeAnnounces, + subscribe_announces_error::SubscribeAnnouncesError, + subscribe_announces_ok::SubscribeAnnouncesOk, }, moqt_payload::MOQTPayload, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; -use crate::modules::moqt_client::MOQTClient; - -pub(crate) async fn subscribe_namespace_handler( - subscribe_namespace_message: SubscribeNamespace, +pub(crate) async fn subscribe_announces_handler( + subscribe_announces_message: SubscribeAnnounces, client: &MOQTClient, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, -) -> Result> { - tracing::trace!("subscribe_namespace_handler start."); + control_message_dispatcher: &mut ControlMessageDispatcher, +) -> Result> { + tracing::trace!("subscribe_announces_handler start."); tracing::debug!( - "subscribe_namespace_message: {:#?}", - subscribe_namespace_message + "subscribe_announces_message: {:#?}", + subscribe_announces_message ); // TODO: auth - let track_namespace_prefix = subscribe_namespace_message.track_namespace_prefix().clone(); + let track_namespace_prefix = subscribe_announces_message.track_namespace_prefix().clone(); // Record the subscribed Track Namespace Prefix let set_result = pubsub_relation_manager_repository @@ -40,21 +38,20 @@ pub(crate) async fn subscribe_namespace_handler( match set_result { Ok(_) => { tracing::info!( - "subscribe_namespaced track_namespace_prefix: {:#?}", + "subscribe_announcesd track_namespace_prefix: {:#?}", track_namespace_prefix.clone() ); - tracing::trace!("subscribe_namespace_handler complete."); + tracing::trace!("subscribe_announces_handler complete."); - // Send SubscribeNamespaceOk message - let subscribe_namespace_ok_message: Box = - Box::new(SubscribeNamespaceOk::new(track_namespace_prefix.clone())); + // Send SubscribeAnnouncesOk message + let subscribe_announces_ok_message: Box = + Box::new(SubscribeAnnouncesOk::new(track_namespace_prefix.clone())); // TODO: Unify the method to send a message to the opposite client itself - let _ = send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + let _ = control_message_dispatcher + .transfer_message_to_control_message_sender_thread( client.id(), - subscribe_namespace_ok_message, - StreamDirection::Bi, + subscribe_announces_ok_message, ) .await; @@ -69,14 +66,13 @@ pub(crate) async fn subscribe_namespace_handler( // TODO: auth parameter let announce_message: Box = Box::new(Announce::new( namespace, - subscribe_namespace_message.parameters().clone(), + subscribe_announces_message.parameters().clone(), )); - let _ = send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + let _ = control_message_dispatcher + .transfer_message_to_control_message_sender_thread( client.id(), announce_message, - StreamDirection::Bi, ) .await; } @@ -87,12 +83,12 @@ pub(crate) async fn subscribe_namespace_handler( // TODO: Separate namespace prefix overlap error Err(err) => { let msg = std::format!( - "subscribe_namespace_handler: set namespace prefix err: {:?}", + "subscribe_announces_handler: set namespace prefix err: {:?}", err.to_string() ); tracing::error!(msg); - Ok(Some(SubscribeNamespaceError::new( + Ok(Some(SubscribeAnnouncesError::new( track_namespace_prefix, 1, msg, @@ -103,24 +99,23 @@ pub(crate) async fn subscribe_namespace_handler( #[cfg(test)] mod success { - use super::subscribe_namespace_handler; + use super::subscribe_announces_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use bytes::BytesMut; use moqt_core::{ - constants::StreamDirection, messages::{ control_messages::{ - subscribe_namespace::SubscribeNamespace, + subscribe_announces::SubscribeAnnounces, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, moqt_payload::MOQTPayload, @@ -132,7 +127,7 @@ mod success { #[tokio::test] async fn normal_case() { - // Generate SUBSCRIBE_NAMESPACE message + // Generate SUBSCRIBE_ANNOUNCES message let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); @@ -140,10 +135,10 @@ mod success { let parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(parameter_value)); let parameters = vec![parameter]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut buf); + subscribe_announces_message.packetize(&mut buf); // Generate client let upstream_session_id = 0; @@ -170,28 +165,30 @@ mod success { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; - // Execute subscribe_namespace_handler and get result - let result = subscribe_namespace_handler( - subscribe_namespace_message, + // Execute subscribe_announces_handler and get result + let result = subscribe_announces_handler( + subscribe_announces_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await .unwrap(); @@ -202,24 +199,23 @@ mod success { #[cfg(test)] mod failure { - use super::subscribe_namespace_handler; + use super::subscribe_announces_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use bytes::BytesMut; use moqt_core::{ - constants::StreamDirection, messages::{ control_messages::{ - subscribe_namespace::SubscribeNamespace, + subscribe_announces::SubscribeAnnounces, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, moqt_payload::MOQTPayload, @@ -231,7 +227,7 @@ mod failure { #[tokio::test] async fn same_prefix() { - // Generate SUBSCRIBE_NAMESPACE message + // Generate SUBSCRIBE_ANNOUNCES message let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); @@ -239,10 +235,10 @@ mod failure { let parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(parameter_value)); let parameters = vec![parameter]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut buf); + subscribe_announces_message.packetize(&mut buf); // Generate client let upstream_session_id = 0; @@ -273,38 +269,40 @@ mod failure { .set_downstream_subscribed_namespace_prefix(track_namespace_prefix.clone(), client.id()) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; - // Execute subscribe_namespace_handler and get result - let result = subscribe_namespace_handler( - subscribe_namespace_message, + // Execute subscribe_announces_handler and get result + let result = subscribe_announces_handler( + subscribe_announces_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await; match result { - Ok(Some(subscribe_namespace_error)) => { + Ok(Some(subscribe_announces_error)) => { assert_eq!( - *subscribe_namespace_error.track_namespace_prefix(), + *subscribe_announces_error.track_namespace_prefix(), track_namespace_prefix ); - assert_eq!(subscribe_namespace_error.error_code(), 1); + assert_eq!(subscribe_announces_error.error_code(), 1); } _ => panic!("Unexpected result: {:?}", result), } @@ -312,7 +310,7 @@ mod failure { #[tokio::test] async fn prefix_overlap_longer() { - // Generate SUBSCRIBE_NAMESPACE message + // Generate SUBSCRIBE_ANNOUNCES message let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); let exists_track_namespace_prefix = Vec::from(["aaa".to_string()]); @@ -321,10 +319,10 @@ mod failure { let parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(parameter_value)); let parameters = vec![parameter]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut buf); + subscribe_announces_message.packetize(&mut buf); // Generate client let upstream_session_id = 0; @@ -358,38 +356,40 @@ mod failure { ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; - // Execute subscribe_namespace_handler and get result - let result = subscribe_namespace_handler( - subscribe_namespace_message, + // Execute subscribe_announces_handler and get result + let result = subscribe_announces_handler( + subscribe_announces_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await; match result { - Ok(Some(subscribe_namespace_error)) => { + Ok(Some(subscribe_announces_error)) => { assert_eq!( - *subscribe_namespace_error.track_namespace_prefix(), + *subscribe_announces_error.track_namespace_prefix(), track_namespace_prefix ); - assert_eq!(subscribe_namespace_error.error_code(), 1); + assert_eq!(subscribe_announces_error.error_code(), 1); } _ => panic!("Unexpected result: {:?}", result), } @@ -397,7 +397,7 @@ mod failure { #[tokio::test] async fn prefix_overlap_shorter() { - // Generate SUBSCRIBE_NAMESPACE message + // Generate SUBSCRIBE_ANNOUNCES message let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); let exists_track_namespace_prefix = @@ -407,10 +407,10 @@ mod failure { let parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(parameter_value)); let parameters = vec![parameter]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut buf); + subscribe_announces_message.packetize(&mut buf); // Generate client let upstream_session_id = 0; @@ -444,38 +444,40 @@ mod failure { ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; - // Execute subscribe_namespace_handler and get result - let result = subscribe_namespace_handler( - subscribe_namespace_message, + // Execute subscribe_announces_handler and get result + let result = subscribe_announces_handler( + subscribe_announces_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await; match result { - Ok(Some(subscribe_namespace_error)) => { + Ok(Some(subscribe_announces_error)) => { assert_eq!( - *subscribe_namespace_error.track_namespace_prefix(), + *subscribe_announces_error.track_namespace_prefix(), track_namespace_prefix ); - assert_eq!(subscribe_namespace_error.error_code(), 1); + assert_eq!(subscribe_announces_error.error_code(), 1); } _ => panic!("Unexpected result: {:?}", result), } @@ -483,7 +485,7 @@ mod failure { #[tokio::test] async fn forward_fail() { - // Generate SUBSCRIBE_NAMESPACE message + // Generate SUBSCRIBE_ANNOUNCES message let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); @@ -491,10 +493,10 @@ mod failure { let parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(parameter_value)); let parameters = vec![parameter]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut buf); + subscribe_announces_message.packetize(&mut buf); // Generate client let upstream_session_id = 0; @@ -521,19 +523,22 @@ mod failure { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher (without set sender) - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher (without set sender) + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); - // Execute subscribe_namespace_handler and get result - let result = subscribe_namespace_handler( - subscribe_namespace_message, + // Execute subscribe_announces_handler and get result + let result = subscribe_announces_handler( + subscribe_announces_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await .unwrap(); @@ -543,7 +548,7 @@ mod failure { #[tokio::test] async fn namespace_not_found() { - // Generate SUBSCRIBE_NAMESPACE message + // Generate SUBSCRIBE_ANNOUNCES message let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_namespace_prefix = Vec::from(["ddd".to_string(), "eee".to_string()]); @@ -551,10 +556,10 @@ mod failure { let parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new(parameter_value)); let parameters = vec![parameter]; - let subscribe_namespace_message = - SubscribeNamespace::new(track_namespace_prefix.clone(), parameters); + let subscribe_announces_message = + SubscribeAnnounces::new(track_namespace_prefix.clone(), parameters); let mut buf = BytesMut::new(); - subscribe_namespace_message.packetize(&mut buf); + subscribe_announces_message.packetize(&mut buf); // Generate client let upstream_session_id = 0; @@ -581,28 +586,30 @@ mod failure { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; - // Execute subscribe_namespace_handler and get result - let result = subscribe_namespace_handler( - subscribe_namespace_message, + // Execute subscribe_announces_handler and get result + let result = subscribe_announces_handler( + subscribe_announces_message, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, ) .await .unwrap(); diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_error_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_error_handler.rs index 9931509d..53f43bba 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_error_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_error_handler.rs @@ -1,18 +1,16 @@ +use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, moqt_client::MOQTClient, +}; use anyhow::{bail, Result}; - use moqt_core::{ - constants::StreamDirection, messages::{control_messages::subscribe_error::SubscribeError, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; -use crate::modules::moqt_client::MOQTClient; - pub(crate) async fn subscribe_error_handler( subscribe_error_message: SubscribeError, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, client: &MOQTClient, ) -> Result<()> { tracing::trace!("subscribe_error_handler start."); @@ -61,11 +59,10 @@ pub(crate) async fn subscribe_error_handler( let forwarding_subscribe_error_message: Box = Box::new(message_payload.clone()); - send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + control_message_dispatcher + .transfer_message_to_control_message_sender_thread( *downstream_session_id, forwarding_subscribe_error_message, - StreamDirection::Bi, ) .await?; @@ -117,21 +114,20 @@ async fn delete_downstream_and_upstream_subscription( mod success { use super::subscribe_error_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use moqt_core::{ - constants::StreamDirection, messages::{ - control_messages::subscribe::{FilterType, GroupOrder}, control_messages::subscribe_error::{SubscribeError, SubscribeErrorCode}, + control_messages::{group_order::GroupOrder, subscribe::FilterType}, moqt_payload::MOQTPayload, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, @@ -152,7 +148,6 @@ mod success { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; // Generate client let upstream_session_id = 1; @@ -186,7 +181,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -207,7 +201,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -219,18 +212,20 @@ mod success { ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -250,7 +245,7 @@ mod success { let result = subscribe_error_handler( subscribe_error, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; @@ -263,19 +258,20 @@ mod success { mod failure { use super::subscribe_error_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use moqt_core::{ messages::control_messages::{ - subscribe::{FilterType, GroupOrder}, + group_order::GroupOrder, + subscribe::FilterType, subscribe_error::{SubscribeError, SubscribeErrorCode}, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, @@ -295,7 +291,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; // Generate client let upstream_session_id = 1; @@ -329,7 +324,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -350,7 +344,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -362,12 +355,15 @@ mod failure { ) .await; - // Generate SendStreamDispacher (without set sender) - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher (without set sender) + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let error_code = SubscribeErrorCode::InternalError; let reason_phrase = "test".to_string(); @@ -384,7 +380,7 @@ mod failure { let result = subscribe_error_handler( subscribe_error, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_handler.rs index a0266aa3..c8e4db9e 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_handler.rs @@ -1,5 +1,6 @@ use crate::{ modules::{ + control_message_dispatcher::ControlMessageDispatcher, moqt_client::MOQTClient, object_cache_storage::{cache::CacheKey, wrapper::ObjectCacheStorageWrapper}, }, @@ -7,19 +8,17 @@ use crate::{ }; use anyhow::{bail, Result}; use moqt_core::{ - constants::StreamDirection, data_stream_type::DataStreamType, messages::{ control_messages::{ - subscribe::Subscribe, + subscribe::{FilterType, Subscribe}, subscribe_error::{SubscribeError, SubscribeErrorCode}, subscribe_ok::SubscribeOk, }, moqt_payload::MOQTPayload, }, - models::tracks::ForwardingPreference, + models::{range::ObjectStart, tracks::ForwardingPreference}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; @@ -28,7 +27,7 @@ pub(crate) async fn subscribe_handler( subscribe_message: Subscribe, client: &MOQTClient, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, object_cache_storage: &mut ObjectCacheStorageWrapper, start_forwarder_txes: Arc>>, ) -> Result> { @@ -76,18 +75,26 @@ pub(crate) async fn subscribe_handler( // If the track already exists, return the track as it is if pubsub_relation_manager_repository - .is_track_existing( + .is_upstream_subscribed( subscribe_message.track_namespace().to_vec(), subscribe_message.track_name().to_string(), ) .await .unwrap() { - // Generate message -> Set subscription -> Send message - let subscribe_ok_message = match generate_subscribe_ok_message( + let (content_exists, largest_group_id, largest_object_id) = check_existing_contents( + &subscribe_message, pubsub_relation_manager_repository, object_cache_storage, + ) + .await?; + + // Generate message -> Set subscription -> Send message + let subscribe_ok_message = match generate_subscribe_ok_message( &subscribe_message, + content_exists, + largest_group_id, + largest_object_id, ) .await { @@ -139,15 +146,24 @@ pub(crate) async fn subscribe_handler( let subscribe_ok_payload: Box = Box::new(subscribe_ok_message.clone()); // TODO: Unify the method to send a message to the opposite client itself - send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( - client.id(), - subscribe_ok_payload, - StreamDirection::Bi, - ) + control_message_dispatcher + .transfer_message_to_control_message_sender_thread(client.id(), subscribe_ok_payload) .await?; - if subscribe_ok_message.content_exists() { + if content_exists { + // Store Largest Group/Object ID to culculate the Joining FETCH range + if subscribe_message.filter_type() == FilterType::LatestObject { + let actual_object_start = + ObjectStart::new(largest_group_id.unwrap(), largest_object_id.unwrap()); + pubsub_relation_manager_repository + .set_downstream_actual_object_start( + downstream_session_id, + downstream_subscribe_id, + actual_object_start, + ) + .await?; + } + start_new_forwarder( pubsub_relation_manager_repository, object_cache_storage, @@ -206,7 +222,6 @@ pub(crate) async fn subscribe_handler( subscribe_message.start_group(), subscribe_message.start_object(), subscribe_message.end_group(), - subscribe_message.end_object(), subscribe_message.subscribe_parameters().clone(), ) .unwrap(); @@ -217,11 +232,10 @@ pub(crate) async fn subscribe_handler( // Notify to the publisher about the SUBSCRIBE message // TODO: Wait for the SUBSCRIBE_OK message to be returned on a transaction // TODO: validate Timeout - match send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + match control_message_dispatcher + .transfer_message_to_control_message_sender_thread( session_id, forwarding_subscribe_message, - StreamDirection::Bi, ) .await { @@ -293,11 +307,6 @@ async fn start_new_forwarder( .await? .unwrap(); - let upstream_subscription = pubsub_relation_manager_repository - .get_upstream_subscription_by_ids(upstream_session_id, upstream_subscribe_id) - .await? - .unwrap(); - let start_forwarder_tx = start_forwarder_txes .lock() .await @@ -305,7 +314,11 @@ async fn start_new_forwarder( .unwrap() .clone(); - let forwarding_preference = upstream_subscription.get_forwarding_preference().unwrap(); + let forwarding_preference = pubsub_relation_manager_repository + .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) + .await? + .unwrap(); + match forwarding_preference { ForwardingPreference::Datagram => { let data_stream_type = DataStreamType::ObjectDatagram; @@ -313,20 +326,15 @@ async fn start_new_forwarder( .send((downstream_subscribe_id, data_stream_type, None)) .await; } - ForwardingPreference::Track => { - let data_stream_type = DataStreamType::StreamHeaderTrack; - let _ = start_forwarder_tx - .send((downstream_subscribe_id, data_stream_type, None)) - .await; - } - // If SUBSCRIBE message does not handle past objects, it is only necessary to open forwarders for subgroups of the current group ForwardingPreference::Subgroup => { - let data_stream_type = DataStreamType::StreamHeaderSubgroup; + let data_stream_type = DataStreamType::SubgroupHeader; let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); - let group_id = object_cache_storage - .get_largest_group_id(&cache_key) - .await?; + let group_id = match object_cache_storage.get_largest_group_id(&cache_key).await { + Ok(Some(group_id)) => group_id, + Ok(None) => bail!("largest group id is none"), + Err(err) => bail!("Failed to get largest group id: {:?}", err), + }; let start_group = subscribe_message.start_group(); if start_group.is_some() && start_group.unwrap() > group_id { @@ -335,6 +343,12 @@ async fn start_new_forwarder( return Ok(()); } + let end_group = subscribe_message.end_group(); + if end_group.is_some() && end_group.unwrap() < group_id { + // If the end_group is smaller than the largest group_id, there is no need to open forwarders + return Ok(()); + } + let subgroup_ids = object_cache_storage .get_all_subgroup_ids(&cache_key, group_id) .await?; @@ -355,11 +369,11 @@ async fn start_new_forwarder( Ok(()) } -async fn generate_subscribe_ok_message( +async fn check_existing_contents( + subscribe_message: &Subscribe, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, object_cache_storage: &mut ObjectCacheStorageWrapper, - subscribe_message: &Subscribe, -) -> Result { +) -> Result<(bool, Option, Option)> { let upstream_session_id = pubsub_relation_manager_repository .get_upstream_session_id(subscribe_message.track_namespace().clone()) .await? @@ -375,20 +389,26 @@ async fn generate_subscribe_ok_message( .unwrap(); let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); - let largest_group_id = match object_cache_storage.get_largest_group_id(&cache_key).await { - Ok(group_id) => Some(group_id), - Err(_) => None, - }; + let largest_group_id = + (object_cache_storage.get_largest_group_id(&cache_key).await).unwrap_or_default(); - let largest_object_id = match object_cache_storage.get_largest_object_id(&cache_key).await { - Ok(object_id) => Some(object_id), - Err(_) => None, - }; + let largest_object_id = + (object_cache_storage.get_largest_object_id(&cache_key).await).unwrap_or_default(); - // TODO: check cache duration - let expires = 0; // If the largest_group_id or largest_object_id is None, the content does not exist let content_exists = largest_group_id.is_some() && largest_object_id.is_some(); + + Ok((content_exists, largest_group_id, largest_object_id)) +} + +async fn generate_subscribe_ok_message( + subscribe_message: &Subscribe, + content_exists: bool, + largest_group_id: Option, + largest_object_id: Option, +) -> Result { + // TODO: check cache duration + let expires = 0; // TODO: check DELIVERY TIMEOUT let subscribe_parameters = vec![]; // TODO: accurate group_order @@ -423,16 +443,6 @@ async fn set_downstream_subscription( let downstream_start_group = subscribe_message.start_group(); let downstream_start_object = subscribe_message.start_object(); let downstream_end_group = subscribe_message.end_group(); - let downstream_end_object = subscribe_message.end_object(); - - // Get publisher subscription already exists - let upstream_subscription = pubsub_relation_manager_repository - .get_upstream_subscription_by_full_track_name( - downstream_track_namespace.clone(), - downstream_track_name.clone(), - ) - .await? - .unwrap(); pubsub_relation_manager_repository .set_downstream_subscription( @@ -447,17 +457,16 @@ async fn set_downstream_subscription( downstream_start_group, downstream_start_object, downstream_end_group, - downstream_end_object, ) .await?; let upstream_session_id = pubsub_relation_manager_repository - .get_upstream_session_id(downstream_track_namespace) + .get_upstream_session_id(downstream_track_namespace.clone()) .await? .unwrap(); - let (upstream_track_namespace, upstream_track_name) = - upstream_subscription.get_track_namespace_and_name(); + let upstream_track_namespace = downstream_track_namespace; + let upstream_track_name = downstream_track_name; // Get publisher subscribe id to register pubsub relation let upstream_subscribe_id = pubsub_relation_manager_repository @@ -502,7 +511,6 @@ async fn set_downstream_and_upstream_subscription( let downstream_start_group = subscribe_message.start_group(); let downstream_start_object = subscribe_message.start_object(); let downstream_end_group = subscribe_message.end_group(); - let downstream_end_object = subscribe_message.end_object(); pubsub_relation_manager_repository .set_downstream_subscription( @@ -517,7 +525,6 @@ async fn set_downstream_and_upstream_subscription( downstream_start_group, downstream_start_object, downstream_end_group, - downstream_end_object, ) .await?; @@ -532,7 +539,6 @@ async fn set_downstream_and_upstream_subscription( downstream_start_group, downstream_start_object, downstream_end_group, - downstream_end_object, ) .await?; @@ -554,6 +560,9 @@ mod success { use crate::SenderToOpenSubscription; use crate::{ modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, object_cache_storage::{ cache::CacheKey, commands::ObjectCacheStorageCommand, @@ -564,23 +573,20 @@ mod success { manager::pubsub_relation_manager, wrapper::{test_helper_fn, PubSubRelationManagerWrapper}, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }, SubgroupStreamId, }; + use moqt_core::messages::data_streams::subgroup_stream; use moqt_core::models::tracks::ForwardingPreference; use moqt_core::{ - constants::StreamDirection, data_stream_type::DataStreamType, messages::{ control_messages::{ - subscribe::{FilterType, GroupOrder, Subscribe}, + group_order::GroupOrder, + subscribe::{FilterType, Subscribe}, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, - data_streams::track_stream, moqt_payload::MOQTPayload, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, @@ -602,7 +608,6 @@ mod success { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -618,7 +623,6 @@ mod success { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -650,18 +654,20 @@ mod success { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -681,7 +687,7 @@ mod success { subscribe, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -711,6 +717,7 @@ mod success { } #[tokio::test] + // Return SUBSCRIBE_OK immediately but its ContentExists is false async fn normal_case_track_exists_and_content_not_exists() { // Generate SUBSCRIBE message let subscribe_id = 0; @@ -723,7 +730,6 @@ mod success { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -739,7 +745,6 @@ mod success { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -777,7 +782,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -786,25 +790,26 @@ mod success { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx.clone(), }) .await; - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -824,7 +829,7 @@ mod success { subscribe, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -852,6 +857,9 @@ mod success { } #[tokio::test] + // Return SUBSCRIBE_OK immediately + // ContentExists is true + // If, FilterType is LatestObject, the largest group_id and object_id are stored as actual_object_start async fn normal_case_track_exists_and_content_exists() { // Generate SUBSCRIBE message let subscribe_id = 0; @@ -860,11 +868,10 @@ mod success { let track_name = "track_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; - let filter_type = FilterType::LatestGroup; + let filter_type = FilterType::LatestObject; let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -880,7 +887,6 @@ mod success { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -918,11 +924,10 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); - let forwarding_preference = ForwardingPreference::Track; + let forwarding_preference = ForwardingPreference::Subgroup; let _ = pubsub_relation_manager .set_upstream_forwarding_preference( upstream_session_id, @@ -935,25 +940,26 @@ mod success { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx.clone(), }) .await; - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -965,28 +971,41 @@ mod success { let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); let group_id = 0; + let subgroup_id = 0; let object_status = None; let duration = 1000; let publisher_priority = 0; + let extension_headers = vec![]; - let track_header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); + let subgroup_header = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); let _ = object_cache_storage - .create_track_stream_cache(&cache_key, track_header) + .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, subgroup_header) .await; for i in 0..10 { let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; let object_id = i as u64; - let track_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); + let subgroup_object = subgroup_stream::Object::new( + object_id, + extension_headers.clone(), + object_status, + object_payload, + ) + .unwrap(); let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_object, duration) + .set_subgroup_stream_object( + &cache_key, + group_id, + subgroup_id, + subgroup_object, + duration, + ) .await; } @@ -1009,7 +1028,7 @@ mod success { subscribe, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -1034,6 +1053,15 @@ mod success { assert_eq!(downstream_session_id, downstream_session_id); assert_eq!(downstream_subscribe_id, downstream_subscribe_id); + + let actual_object_start = pubsub_relation_manager + .get_downstream_actual_object_start(*downstream_session_id, *downstream_subscribe_id) + .await + .unwrap() + .unwrap(); + + assert_eq!(actual_object_start.group_id(), group_id); + assert_eq!(actual_object_start.object_id(), 9); } } @@ -1041,6 +1069,9 @@ mod success { mod failure { use super::subscribe_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, object_cache_storage::{ commands::ObjectCacheStorageCommand, storage::object_cache_storage, @@ -1050,17 +1081,14 @@ mod failure { commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use crate::SenderToOpenSubscription; use moqt_core::{ - constants::StreamDirection, messages::{ control_messages::{ - subscribe::{FilterType, GroupOrder, Subscribe}, + group_order::GroupOrder, + subscribe::{FilterType, Subscribe}, subscribe_error::SubscribeErrorCode, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, @@ -1084,7 +1112,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -1100,7 +1127,6 @@ mod failure { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -1143,22 +1169,23 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -1178,7 +1205,7 @@ mod failure { subscribe, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -1201,7 +1228,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -1217,7 +1243,6 @@ mod failure { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -1248,12 +1273,15 @@ mod failure { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher (without set sender) - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher (without set sender) + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -1270,7 +1298,7 @@ mod failure { subscribe, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -1299,7 +1327,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -1315,7 +1342,6 @@ mod failure { start_group, start_object, end_group, - end_object, subscribe_parameters, ) .unwrap(); @@ -1339,18 +1365,20 @@ mod failure { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -1370,7 +1398,7 @@ mod failure { subscribe, &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -1399,7 +1427,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -1418,7 +1445,6 @@ mod failure { start_group, start_object, end_group, - end_object, subscribe_parameters.clone(), ) .unwrap(); @@ -1453,18 +1479,20 @@ mod failure { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -1484,7 +1512,7 @@ mod failure { subscribes[0].clone(), &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes.clone(), ) @@ -1494,7 +1522,7 @@ mod failure { subscribes[1].clone(), &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) @@ -1516,7 +1544,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -1535,7 +1562,6 @@ mod failure { start_group, start_object, end_group, - end_object, subscribe_parameters.clone(), ) .unwrap(); @@ -1570,18 +1596,20 @@ mod failure { .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -1601,7 +1629,7 @@ mod failure { subscribes[0].clone(), &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes.clone(), ) @@ -1611,7 +1639,7 @@ mod failure { subscribes[1].clone(), &client, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, start_forwarder_txes, ) diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_ok_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_ok_handler.rs index 62b4a119..031105a1 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_ok_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/subscribe_ok_handler.rs @@ -1,17 +1,16 @@ +use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, moqt_client::MOQTClient, +}; use anyhow::Result; use moqt_core::{ - constants::StreamDirection, messages::{control_messages::subscribe_ok::SubscribeOk, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; -use crate::modules::moqt_client::MOQTClient; - pub(crate) async fn subscribe_ok_handler( subscribe_ok_message: SubscribeOk, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, client: &MOQTClient, ) -> Result<()> { tracing::trace!("subscribe_ok_handler start."); @@ -49,11 +48,10 @@ pub(crate) async fn subscribe_ok_handler( ); let subscribe_ok_message: Box = Box::new(message_payload.clone()); - send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + control_message_dispatcher + .transfer_message_to_control_message_sender_thread( *downstream_session_id, subscribe_ok_message, - StreamDirection::Bi, ) .await?; @@ -90,22 +88,22 @@ pub(crate) async fn subscribe_ok_handler( mod success { use super::subscribe_ok_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use bytes::BytesMut; use moqt_core::{ - constants::StreamDirection, messages::{ control_messages::{ - subscribe::{FilterType, GroupOrder}, + group_order::GroupOrder, + subscribe::FilterType, subscribe_ok::SubscribeOk, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, @@ -133,7 +131,6 @@ mod success { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -182,7 +179,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -203,7 +199,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -215,18 +210,20 @@ mod success { ) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -235,7 +232,7 @@ mod success { let result = subscribe_ok_handler( subscribe_ok, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; @@ -248,22 +245,22 @@ mod success { mod failure { use super::subscribe_ok_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, pubsub_relation_manager::{ commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use bytes::BytesMut; use moqt_core::{ - constants::StreamDirection, messages::{ control_messages::{ - subscribe::{FilterType, GroupOrder}, + group_order::GroupOrder, + subscribe::FilterType, subscribe_ok::SubscribeOk, version_specific_parameters::{AuthorizationInfo, VersionSpecificParameter}, }, @@ -291,7 +288,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -340,7 +336,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -361,7 +356,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -373,18 +367,21 @@ mod failure { ) .await; - // Generate SendStreamDispacher (without set sender) - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher (without set sender) + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); // Execute subscribe_ok_handler and get result let result = subscribe_ok_handler( subscribe_ok, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; @@ -408,7 +405,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -457,23 +453,24 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await .unwrap(); - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -482,7 +479,7 @@ mod failure { let result = subscribe_ok_handler( subscribe_ok, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; @@ -507,7 +504,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -556,7 +552,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -577,7 +572,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -593,18 +587,20 @@ mod failure { .activate_downstream_subscription(downstream_session_id, downstream_subscribe_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -613,7 +609,7 @@ mod failure { let result = subscribe_ok_handler( subscribe_ok, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; @@ -638,7 +634,6 @@ mod failure { let start_group = None; let start_object = None; let end_group = None; - let end_object = None; let version_specific_parameter = VersionSpecificParameter::AuthorizationInfo(AuthorizationInfo::new("test".to_string())); let subscribe_parameters = vec![version_specific_parameter]; @@ -687,7 +682,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -708,7 +702,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -724,18 +717,20 @@ mod failure { .activate_upstream_subscription(upstream_session_id, upstream_subscribe_id) .await; - // Generate SendStreamDispacher - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); + // Generate ControlMessageDispacher + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let mut send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let mut control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_tx - .send(SendStreamDispatchCommand::Set { + let _ = control_message_dispatch_tx + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -744,7 +739,7 @@ mod failure { let result = subscribe_ok_handler( subscribe_ok, &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; diff --git a/moqt-server/src/modules/message_handlers/control_message/handlers/unsubscribe_handler.rs b/moqt-server/src/modules/message_handlers/control_message/handlers/unsubscribe_handler.rs index 10881dde..98d6e610 100644 --- a/moqt-server/src/modules/message_handlers/control_message/handlers/unsubscribe_handler.rs +++ b/moqt-server/src/modules/message_handlers/control_message/handlers/unsubscribe_handler.rs @@ -1,19 +1,19 @@ -use crate::modules::moqt_client::MOQTClient; +use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, moqt_client::MOQTClient, +}; use anyhow::Result; use moqt_core::{ - constants::StreamDirection, messages::control_messages::{ subscribe_done::{StatusCode, SubscribeDone}, unsubscribe::Unsubscribe, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; pub(crate) async fn unsubscribe_handler( unsubscribe_message: Unsubscribe, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, client: &MOQTClient, ) -> Result<()> { tracing::trace!("unsubscribe_handler start."); @@ -47,12 +47,8 @@ pub(crate) async fn unsubscribe_handler( None, None, )); - send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( - client.id(), - subscribe_done_message, - StreamDirection::Bi, - ) + control_message_dispatcher + .transfer_message_to_control_message_sender_thread(client.id(), subscribe_done_message) .await?; // 3. If the number of DownStream Subscriptions is zero, send UNSUBSCRIBE to the Original Publisher. @@ -61,11 +57,10 @@ pub(crate) async fn unsubscribe_handler( .await?; if downstream_subscribers.is_empty() { let unsubscribe_message = Box::new(Unsubscribe::new(upstream_subscribe_id)); - send_stream_dispatcher_repository - .transfer_message_to_send_stream_thread( + control_message_dispatcher + .transfer_message_to_control_message_sender_thread( upstream_session_id, unsubscribe_message, - StreamDirection::Bi, ) .await?; } @@ -80,6 +75,9 @@ mod success { use super::unsubscribe_handler; use crate::modules::{ + control_message_dispatcher::{ + control_message_dispatcher, ControlMessageDispatchCommand, ControlMessageDispatcher, + }, moqt_client::MOQTClient, object_cache_storage::{ commands::ObjectCacheStorageCommand, storage::object_cache_storage, @@ -89,16 +87,12 @@ mod success { commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }, - send_stream_dispatcher::{ - send_stream_dispatcher, SendStreamDispatchCommand, SendStreamDispatcher, - }, server_processes::senders, }; use crate::SenderToOpenSubscription; - use moqt_core::constants::StreamDirection; use moqt_core::{ messages::{ - control_messages::subscribe::{FilterType, GroupOrder}, + control_messages::{group_order::GroupOrder, subscribe::FilterType}, moqt_payload::MOQTPayload, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, @@ -115,12 +109,15 @@ mod success { pubsub_relation_manager_wrapper } - async fn spawn_send_stream_dispatcher() -> SendStreamDispatcher { - let (send_stream_tx, mut send_stream_rx) = mpsc::channel::(1024); - tokio::spawn(async move { send_stream_dispatcher(&mut send_stream_rx).await }); - let send_stream_dispatcher: SendStreamDispatcher = - SendStreamDispatcher::new(send_stream_tx.clone()); - send_stream_dispatcher + async fn spawn_control_message_dispatcher() -> ControlMessageDispatcher { + let (control_message_dispatch_tx, mut control_message_dispatch_rx) = + mpsc::channel::(1024); + tokio::spawn( + async move { control_message_dispatcher(&mut control_message_dispatch_rx).await }, + ); + let control_message_dispatcher: ControlMessageDispatcher = + ControlMessageDispatcher::new(control_message_dispatch_tx.clone()); + control_message_dispatcher } async fn spawn_object_cache_storage() -> ObjectCacheStorageWrapper { @@ -143,12 +140,12 @@ mod success { async fn initialize() -> ( PubSubRelationManagerWrapper, - SendStreamDispatcher, + ControlMessageDispatcher, ObjectCacheStorageWrapper, Arc>>, MOQTClient, ) { - let send_stream_dispatcher = spawn_send_stream_dispatcher().await; + let control_message_dispatcher = spawn_control_message_dispatcher().await; let object_cache_storage_wrapper = spawn_object_cache_storage().await; let pubsub_relation_manager_wrapper = spawn_pubsub_relation_manager().await; let start_forwarder_txes = create_start_fowarder_txes().await; @@ -156,7 +153,7 @@ mod success { ( pubsub_relation_manager_wrapper, - send_stream_dispatcher, + control_message_dispatcher, object_cache_storage_wrapper, start_forwarder_txes, client, @@ -165,17 +162,16 @@ mod success { async fn setup_upstream_subscription( pubsub_relation_manager_wrapper: PubSubRelationManagerWrapper, - send_stream_dispatcher: SendStreamDispatcher, + control_message_dispatcher: ControlMessageDispatcher, upstream_session_id: usize, track_namespace: Vec, track_name: String, ) -> u64 { let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_dispatcher + let _ = control_message_dispatcher .get_tx() - .send(SendStreamDispatchCommand::Set { + .send(ControlMessageDispatchCommand::Set { session_id: upstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -196,7 +192,6 @@ mod success { None, None, None, - None, ) .await .unwrap(); @@ -206,18 +201,17 @@ mod success { async fn setup_downstream_subscription( pubsub_relation_manager_wrapper: PubSubRelationManagerWrapper, - send_stream_dispatcher: SendStreamDispatcher, + control_message_dispatcher: ControlMessageDispatcher, downstream_session_id: usize, downstream_subscribe_id: u64, track_namespace: Vec, track_name: String, ) { let (message_tx, _) = mpsc::channel::>>(1024); - let _ = send_stream_dispatcher + let _ = control_message_dispatcher .get_tx() - .send(SendStreamDispatchCommand::Set { + .send(ControlMessageDispatchCommand::Set { session_id: downstream_session_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await; @@ -237,14 +231,13 @@ mod success { None, None, None, - None, ) .await; } async fn setup_e2e_subscription( pubsub_relation_manager_wrapper: PubSubRelationManagerWrapper, - send_stream_dispatcher: SendStreamDispatcher, + control_message_dispatcher: ControlMessageDispatcher, ) -> (usize, u64, usize, u64) { let upstream_session_id = 0; let downstream_session_id = 10; @@ -254,7 +247,7 @@ mod success { let upstream_subscribe_id = setup_upstream_subscription( pubsub_relation_manager_wrapper.clone(), - send_stream_dispatcher.clone(), + control_message_dispatcher.clone(), upstream_session_id, track_namespace.clone(), track_name.clone(), @@ -262,7 +255,7 @@ mod success { .await; setup_downstream_subscription( pubsub_relation_manager_wrapper.clone(), - send_stream_dispatcher.clone(), + control_message_dispatcher.clone(), downstream_session_id, downstream_subscribe_id, track_namespace, @@ -295,7 +288,7 @@ mod success { let ( mut pubsub_relation_manager_wrapper, - mut send_stream_dispatcher, + mut control_message_dispatcher, _object_cache_storage, _start_forwarder_txes, client, @@ -307,14 +300,14 @@ mod success { _downstream_subscribe_id, ) = setup_e2e_subscription( pubsub_relation_manager_wrapper.clone(), - send_stream_dispatcher.clone(), + control_message_dispatcher.clone(), ) .await; let result = unsubscribe_handler( unsubscribe, &mut pubsub_relation_manager_wrapper, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; @@ -334,7 +327,7 @@ mod success { let ( mut pubsub_relation_manager_wrapper, - mut send_stream_dispatcher, + mut control_message_dispatcher, _object_cache_storage, _start_forwarder_txes, client, @@ -347,7 +340,7 @@ mod success { _downstream_subscribe_id, ) = setup_e2e_subscription( pubsub_relation_manager_wrapper.clone(), - send_stream_dispatcher.clone(), + control_message_dispatcher.clone(), ) .await; @@ -356,7 +349,7 @@ mod success { setup_downstream_subscription( pubsub_relation_manager_wrapper.clone(), - send_stream_dispatcher.clone(), + control_message_dispatcher.clone(), second_downstream_session_id, second_downstream_subscribe_id, Vec::from(["test".to_string(), "test".to_string()]), @@ -376,7 +369,7 @@ mod success { let result = unsubscribe_handler( unsubscribe, &mut pubsub_relation_manager_wrapper, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &client, ) .await; diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes.rs index f0b10227..ab60b8d6 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes.rs @@ -2,8 +2,8 @@ pub(crate) mod announce_error_message; pub(crate) mod announce_message; pub(crate) mod announce_ok_message; pub(crate) mod client_setup_message; +pub(crate) mod subscribe_announces_message; pub(crate) mod subscribe_error_message; pub(crate) mod subscribe_message; -pub(crate) mod subscribe_namespace_message; pub(crate) mod subscribe_ok_message; pub(crate) mod unsubscribe_message; diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes/announce_message.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes/announce_message.rs index a9fe7d7f..d44d0a6b 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes/announce_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes/announce_message.rs @@ -1,18 +1,16 @@ +use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, + message_handlers::control_message::handlers::announce_handler::announce_handler, + moqt_client::MOQTClient, +}; use anyhow::{bail, Result}; use bytes::BytesMut; - use moqt_core::{ messages::{ control_messages::{announce::Announce, announce_error::AnnounceError}, moqt_payload::MOQTPayload, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - send_stream_dispatcher_repository::SendStreamDispatcherRepository, -}; - -use crate::modules::{ - message_handlers::control_message::handlers::announce_handler::announce_handler, - moqt_client::MOQTClient, }; pub(crate) async fn process_announce_message( @@ -20,7 +18,7 @@ pub(crate) async fn process_announce_message( client: &MOQTClient, write_buf: &mut BytesMut, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, ) -> Result> { let announce_message = match Announce::depacketize(payload_buf) { Ok(announce_message) => announce_message, @@ -34,7 +32,7 @@ pub(crate) async fn process_announce_message( announce_message, client, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, ) .await; diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_namespace_message.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_announces_message.rs similarity index 50% rename from moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_namespace_message.rs rename to moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_announces_message.rs index 16eedc04..16616e9e 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_namespace_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_announces_message.rs @@ -4,52 +4,52 @@ use bytes::BytesMut; use moqt_core::{ messages::{ control_messages::{ - subscribe_namespace::SubscribeNamespace, - subscribe_namespace_error::SubscribeNamespaceError, + subscribe_announces::SubscribeAnnounces, + subscribe_announces_error::SubscribeAnnouncesError, }, moqt_payload::MOQTPayload, }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; use crate::modules::{ - message_handlers::control_message::handlers::subscribe_namespace_handler::subscribe_namespace_handler, + control_message_dispatcher::ControlMessageDispatcher, + message_handlers::control_message::handlers::subscribe_announces_handler::subscribe_announces_handler, moqt_client::MOQTClient, }; -pub(crate) async fn process_subscribe_namespace_message( +pub(crate) async fn process_subscribe_announces_message( payload_buf: &mut BytesMut, client: &MOQTClient, write_buf: &mut BytesMut, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, -) -> Result> { - let subscribe_namespace_message = match SubscribeNamespace::depacketize(payload_buf) { - Ok(subscribe_namespace_message) => subscribe_namespace_message, + control_message_dispatcher: &mut ControlMessageDispatcher, +) -> Result> { + let subscribe_announces_message = match SubscribeAnnounces::depacketize(payload_buf) { + Ok(subscribe_announces_message) => subscribe_announces_message, Err(err) => { tracing::error!("{:#?}", err); bail!(err.to_string()); } }; - let result = subscribe_namespace_handler( - subscribe_namespace_message, + let result = subscribe_announces_handler( + subscribe_announces_message, client, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, ) .await; match result.as_ref() { - Ok(Some(subscribe_namespace_error)) => { - subscribe_namespace_error.packetize(write_buf); + Ok(Some(subscribe_announces_error)) => { + subscribe_announces_error.packetize(write_buf); result } Ok(None) => result, Err(err) => { - tracing::error!("subscribe_namespace_handler: err: {:?}", err.to_string()); + tracing::error!("subscribe_announces_handler: err: {:?}", err.to_string()); bail!(err.to_string()); } } diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_error_message.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_error_message.rs index 7648273d..e668b0b4 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_error_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_error_message.rs @@ -1,4 +1,5 @@ use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, message_handlers::control_message::handlers::subscribe_error_handler::subscribe_error_handler, moqt_client::MOQTClient, }; @@ -7,13 +8,12 @@ use bytes::BytesMut; use moqt_core::{ messages::{control_messages::subscribe_error::SubscribeError, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; pub(crate) async fn process_subscribe_error_message( payload_buf: &mut BytesMut, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, client: &MOQTClient, ) -> Result<()> { let subscribe_error_message = match SubscribeError::depacketize(payload_buf) { @@ -27,7 +27,7 @@ pub(crate) async fn process_subscribe_error_message( subscribe_error_handler( subscribe_error_message, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, client, ) .await diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_message.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_message.rs index 21595a2c..7ea7b5f1 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_message.rs @@ -1,3 +1,4 @@ +use crate::modules::control_message_dispatcher::ControlMessageDispatcher; use crate::modules::moqt_client::MOQTClient; use crate::modules::{ message_handlers::control_message::handlers::subscribe_handler::subscribe_handler, @@ -10,7 +11,6 @@ use moqt_core::{ messages::control_messages::subscribe_error::SubscribeError, messages::{control_messages::subscribe::Subscribe, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; @@ -20,7 +20,7 @@ pub(crate) async fn process_subscribe_message( client: &MOQTClient, write_buf: &mut BytesMut, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, object_cache_storage: &mut ObjectCacheStorageWrapper, start_forwarder_txes: Arc>>, ) -> Result> { @@ -36,7 +36,7 @@ pub(crate) async fn process_subscribe_message( subscribe_request_message, client, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, object_cache_storage, start_forwarder_txes, ) diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_ok_message.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_ok_message.rs index 57c70e3c..c3f5d2af 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_ok_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes/subscribe_ok_message.rs @@ -4,10 +4,10 @@ use bytes::BytesMut; use moqt_core::{ messages::{control_messages::subscribe_ok::SubscribeOk, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, message_handlers::control_message::handlers::subscribe_ok_handler::subscribe_ok_handler, moqt_client::MOQTClient, }; @@ -15,7 +15,7 @@ use crate::modules::{ pub(crate) async fn process_subscribe_ok_message( payload_buf: &mut BytesMut, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, client: &MOQTClient, ) -> Result<()> { let subscribe_ok_message = match SubscribeOk::depacketize(payload_buf) { @@ -29,7 +29,7 @@ pub(crate) async fn process_subscribe_ok_message( subscribe_ok_handler( subscribe_ok_message, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, client, ) .await diff --git a/moqt-server/src/modules/message_handlers/control_message/server_processes/unsubscribe_message.rs b/moqt-server/src/modules/message_handlers/control_message/server_processes/unsubscribe_message.rs index 48185c50..801959bb 100644 --- a/moqt-server/src/modules/message_handlers/control_message/server_processes/unsubscribe_message.rs +++ b/moqt-server/src/modules/message_handlers/control_message/server_processes/unsubscribe_message.rs @@ -1,4 +1,5 @@ use crate::modules::{ + control_message_dispatcher::ControlMessageDispatcher, message_handlers::control_message::handlers::unsubscribe_handler::unsubscribe_handler, moqt_client::MOQTClient, }; @@ -7,13 +8,12 @@ use bytes::BytesMut; use moqt_core::{ messages::{control_messages::unsubscribe::Unsubscribe, moqt_payload::MOQTPayload}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, - SendStreamDispatcherRepository, }; pub(crate) async fn process_unsubscribe_message( payload_buf: &mut BytesMut, pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, - send_stream_dispatcher_repository: &mut dyn SendStreamDispatcherRepository, + control_message_dispatcher: &mut ControlMessageDispatcher, client: &MOQTClient, ) -> Result<()> { let unsubscribe_message = match Unsubscribe::depacketize(payload_buf) { @@ -27,7 +27,7 @@ pub(crate) async fn process_unsubscribe_message( unsubscribe_handler( unsubscribe_message, pubsub_relation_manager_repository, - send_stream_dispatcher_repository, + control_message_dispatcher, client, ) .await diff --git a/moqt-server/src/modules/message_handlers/datagram_object.rs b/moqt-server/src/modules/message_handlers/datagram_object.rs index 545ae318..48425fe4 100644 --- a/moqt-server/src/modules/message_handlers/datagram_object.rs +++ b/moqt-server/src/modules/message_handlers/datagram_object.rs @@ -3,6 +3,8 @@ use crate::modules::moqt_client::MOQTClient; use crate::modules::moqt_client::MOQTClientStatus; use anyhow::{bail, Result}; use bytes::{Buf, BytesMut}; +use moqt_core::messages::data_streams::datagram_status; +use moqt_core::messages::data_streams::DatagramObject; use moqt_core::{ data_stream_type::DataStreamType, messages::data_streams::{datagram, DataStreams}, @@ -14,7 +16,7 @@ use tokio::sync::Mutex; #[derive(Debug, PartialEq)] pub enum DatagramObjectProcessResult { - Success(datagram::Object), + Success(DatagramObject), Continue, Failure(TerminationErrorCode, String), } @@ -31,6 +33,26 @@ fn read_data_stream_type(read_cur: &mut std::io::Cursor<&[u8]>) -> Result) -> Result { + match datagram::Object::depacketize(read_cur) { + Ok(object) => Ok(DatagramObject::ObjectDatagram(object)), + Err(err) => { + bail!(err.to_string()); + } + } +} + +fn depacketize_object_datagram_status( + read_cur: &mut std::io::Cursor<&[u8]>, +) -> Result { + match datagram_status::Object::depacketize(read_cur) { + Ok(object) => Ok(DatagramObject::ObjectDatagramStatus(object)), + Err(err) => { + bail!(err.to_string()); + } + } +} + pub(crate) async fn try_read_object( buf: &mut BytesMut, client: Arc>, @@ -69,8 +91,9 @@ pub(crate) async fn try_read_object( } }; - let result = match data_stream_type { - DataStreamType::ObjectDatagram => datagram::Object::depacketize(&mut read_cur), + let depacketize_result = match data_stream_type { + DataStreamType::ObjectDatagram => depacketize_object_datagram(&mut read_cur), + DataStreamType::ObjectDatagramStatus => depacketize_object_datagram_status(&mut read_cur), _ => { return DatagramObjectProcessResult::Failure( TerminationErrorCode::ProtocolViolation, @@ -78,7 +101,8 @@ pub(crate) async fn try_read_object( ); } }; - match result { + + match depacketize_result { Ok(object) => { buf.advance(read_cur.position() as usize); DatagramObjectProcessResult::Success(object) @@ -103,7 +127,7 @@ mod tests { use bytes::BytesMut; use moqt_core::{ data_stream_type::DataStreamType, - messages::data_streams::{datagram, DataStreams}, + messages::data_streams::{datagram, datagram_status, DataStreams, DatagramObject}, variable_integer::write_variable_integer, }; use std::{io::Cursor, sync::Arc}; @@ -113,11 +137,11 @@ mod tests { async fn datagram_object_success() { let data_stream_type = DataStreamType::ObjectDatagram; let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) 3, // Object ID (i) 4, // Subscriber Priority (8) + 0, // Extension Headers Length (i) 3, // Object Payload Length (i) 0, 1, 2, // Object Payload (..) ]; @@ -134,10 +158,13 @@ mod tests { let result = try_read_object(&mut buf, client).await; + println!("{:?}", result); + let mut buf_without_type = BytesMut::with_capacity(bytes_array.len()); buf_without_type.extend_from_slice(&bytes_array); let mut read_cur = Cursor::new(&buf_without_type[..]); let datagram_object = datagram::Object::depacketize(&mut read_cur).unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); assert_eq!( result, @@ -146,40 +173,16 @@ mod tests { } #[tokio::test] - async fn datagram_object_continue_insufficient_payload() { - let data_stream_type = DataStreamType::ObjectDatagram; + async fn datagram_object_status_success() { + let data_stream_type = DataStreamType::ObjectDatagramStatus; let bytes_array = [ - 0, // Subscribe ID (i) - 1, // Track Alias (i) - 2, // Group ID (i) - 3, // Object ID (i) - 4, // Subscriber Priority (8) - 50, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len() + 8); - buf.extend(write_variable_integer(data_stream_type as u64)); - buf.extend_from_slice(&bytes_array); - - let senders_mock = senders::test_helper_fn::create_senders_mock(); - let upstream_session_id = 0; - - let mut client = MOQTClient::new(upstream_session_id, senders_mock); - client.update_status(MOQTClientStatus::SetUp); - let client = Arc::new(Mutex::new(client)); - - let result = try_read_object(&mut buf, client).await; - - assert_eq!(result, DatagramObjectProcessResult::Continue); - } - - #[tokio::test] - async fn datagram_object_continue_incomplete_message() { - let data_stream_type = DataStreamType::ObjectDatagram; - let bytes_array = [ - 0, // Subscribe ID (i) 1, // Track Alias (i) 2, // Group ID (i) + 3, // Object ID (i) + 4, // Subscriber Priority (8) + 0, // Extension Headers Length (i) + 3, // Object Payload Length (i) + 1, // Object Status (i) ]; let mut buf = BytesMut::with_capacity(bytes_array.len() + 8); buf.extend(write_variable_integer(data_stream_type as u64)); @@ -194,7 +197,18 @@ mod tests { let result = try_read_object(&mut buf, client).await; - assert_eq!(result, DatagramObjectProcessResult::Continue); + println!("{:?}", result); + + let mut buf_without_type = BytesMut::with_capacity(bytes_array.len()); + buf_without_type.extend_from_slice(&bytes_array); + let mut read_cur = Cursor::new(&buf_without_type[..]); + let datagram_object = datagram_status::Object::depacketize(&mut read_cur).unwrap(); + let datagram_object = DatagramObject::ObjectDatagramStatus(datagram_object); + + assert_eq!( + result, + DatagramObjectProcessResult::Success(datagram_object) + ); } } } diff --git a/moqt-server/src/modules/message_handlers/stream_object.rs b/moqt-server/src/modules/message_handlers/stream_object.rs deleted file mode 100644 index 907f5a53..00000000 --- a/moqt-server/src/modules/message_handlers/stream_object.rs +++ /dev/null @@ -1,186 +0,0 @@ -use crate::constants::TerminationErrorCode; -use bytes::{Buf, BytesMut}; -use moqt_core::{ - data_stream_type::DataStreamType, - messages::data_streams::{subgroup_stream, track_stream, DataStreams}, -}; -use std::io::Cursor; - -#[derive(Debug, PartialEq)] -pub enum StreamObjectProcessResult { - Success(StreamObject), - Continue, - Failure(TerminationErrorCode, String), -} - -#[derive(Debug, PartialEq)] -pub enum StreamObject { - Track(track_stream::Object), - Subgroup(subgroup_stream::Object), -} - -pub async fn try_read_object( - buf: &mut BytesMut, - data_stream_type: DataStreamType, -) -> StreamObjectProcessResult { - let payload_length = buf.len(); - tracing::trace!("stream_object_handler! {}", payload_length); - - // Check if the data is exist - if payload_length == 0 { - return StreamObjectProcessResult::Continue; - } - - let mut read_cur = Cursor::new(&buf[..]); - let result = match data_stream_type { - DataStreamType::StreamHeaderTrack => { - track_stream::Object::depacketize(&mut read_cur).map(StreamObject::Track) - } - DataStreamType::StreamHeaderSubgroup => { - subgroup_stream::Object::depacketize(&mut read_cur).map(StreamObject::Subgroup) - } - unknown => { - return StreamObjectProcessResult::Failure( - TerminationErrorCode::ProtocolViolation, - format!("Unknown message type: {:?}", unknown), - ); - } - }; - - match result { - Ok(stream_object) => { - buf.advance(read_cur.position() as usize); - StreamObjectProcessResult::Success(stream_object) - } - Err(err) => { - tracing::warn!("{:#?}", err); - // Reset the cursor position because data for an object has not yet arrived - read_cur.set_position(0); - StreamObjectProcessResult::Continue - } - } -} - -#[cfg(test)] -mod tests { - mod success { - use crate::modules::message_handlers::stream_object::{ - try_read_object, StreamObject, StreamObjectProcessResult, - }; - use bytes::BytesMut; - use moqt_core::{ - data_stream_type::DataStreamType, - messages::data_streams::{subgroup_stream, track_stream, DataStreams}, - }; - use std::io::Cursor; - - #[tokio::test] - async fn stream_object_track_success() { - let data_stream_type = DataStreamType::StreamHeaderTrack; - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 3, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let buf_clone = buf.clone(); - - let result = try_read_object(&mut buf, data_stream_type).await; - - let mut read_cur = Cursor::new(&buf_clone[..]); - let object = track_stream::Object::depacketize(&mut read_cur).unwrap(); - - assert_eq!( - result, - StreamObjectProcessResult::Success(StreamObject::Track(object)) - ); - } - - #[tokio::test] - async fn stream_object_subgroup_success() { - let data_stream_type = DataStreamType::StreamHeaderSubgroup; - let bytes_array = [ - 0, // Object ID (i) - 3, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - let buf_clone = buf.clone(); - - let result = try_read_object(&mut buf, data_stream_type).await; - - let mut read_cur = Cursor::new(&buf_clone[..]); - let object = subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); - - assert_eq!( - result, - StreamObjectProcessResult::Success(StreamObject::Subgroup(object)) - ); - } - - #[tokio::test] - async fn stream_object_track_continue_insufficient_payload() { - let data_stream_type = DataStreamType::StreamHeaderTrack; - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - 50, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - - let result = try_read_object(&mut buf, data_stream_type).await; - - assert_eq!(result, StreamObjectProcessResult::Continue); - } - - #[tokio::test] - async fn stream_object_subgroup_continue_insufficient_payload() { - let data_stream_type = DataStreamType::StreamHeaderSubgroup; - let bytes_array = [ - 0, // Object ID (i) - 50, // Object Payload Length (i) - 0, 1, 2, // Object Payload (..) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - - let result = try_read_object(&mut buf, data_stream_type).await; - - assert_eq!(result, StreamObjectProcessResult::Continue); - } - - #[tokio::test] - async fn stream_object_track_continue_incomplete_message() { - let data_stream_type = DataStreamType::StreamHeaderTrack; - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - - let result = try_read_object(&mut buf, data_stream_type).await; - - assert_eq!(result, StreamObjectProcessResult::Continue); - } - - #[tokio::test] - async fn stream_object_subgroup_continue_incomplete_message() { - let data_stream_type = DataStreamType::StreamHeaderSubgroup; - let bytes_array = [ - 0, // Object ID (i) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len()); - buf.extend_from_slice(&bytes_array); - - let result = try_read_object(&mut buf, data_stream_type).await; - - assert_eq!(result, StreamObjectProcessResult::Continue); - } - } -} diff --git a/moqt-server/src/modules/message_handlers/stream_header.rs b/moqt-server/src/modules/message_handlers/subgroup_stream_header.rs similarity index 54% rename from moqt-server/src/modules/message_handlers/stream_header.rs rename to moqt-server/src/modules/message_handlers/subgroup_stream_header.rs index adbd1d07..f71f270f 100644 --- a/moqt-server/src/modules/message_handlers/stream_header.rs +++ b/moqt-server/src/modules/message_handlers/subgroup_stream_header.rs @@ -6,25 +6,19 @@ use anyhow::{bail, Result}; use bytes::{Buf, BytesMut}; use moqt_core::{ data_stream_type::DataStreamType, - messages::data_streams::{subgroup_stream, track_stream::Header, DataStreams}, + messages::data_streams::{subgroup_stream, DataStreams}, variable_integer::read_variable_integer, }; use std::{io::Cursor, sync::Arc}; use tokio::sync::Mutex; #[derive(Debug, PartialEq)] -pub enum StreamHeaderProcessResult { - Success(StreamHeader), +pub enum SubgroupStreamHeaderProcessResult { + Success(subgroup_stream::Header), Continue, Failure(TerminationErrorCode, String), } -#[derive(Debug, PartialEq)] -pub enum StreamHeader { - Track(Header), - Subgroup(subgroup_stream::Header), -} - fn read_data_stream_type(read_cur: &mut std::io::Cursor<&[u8]>) -> Result { let type_value = match read_variable_integer(read_cur) { Ok(v) => v as u8, @@ -41,13 +35,13 @@ fn read_data_stream_type(read_cur: &mut std::io::Cursor<&[u8]>) -> Result>, -) -> StreamHeaderProcessResult { +) -> SubgroupStreamHeaderProcessResult { let payload_length = buf.len(); tracing::trace!("try to read stream header! {}", payload_length); // Check if the data stream type is exist if payload_length == 0 { - return StreamHeaderProcessResult::Continue; + return SubgroupStreamHeaderProcessResult::Continue; } // check subscription and judge if it is invalid timing @@ -55,7 +49,7 @@ pub async fn try_read_header( if client_status != MOQTClientStatus::SetUp { let message = String::from("Invalid timing"); tracing::error!(message); - return StreamHeaderProcessResult::Failure( + return SubgroupStreamHeaderProcessResult::Failure( TerminationErrorCode::ProtocolViolation, message, ); @@ -70,7 +64,7 @@ pub async fn try_read_header( buf.advance(read_cur.position() as usize); tracing::error!("data_stream_type is wrong: {:?}", err); - return StreamHeaderProcessResult::Failure( + return SubgroupStreamHeaderProcessResult::Failure( TerminationErrorCode::ProtocolViolation, err.to_string(), ); @@ -78,31 +72,23 @@ pub async fn try_read_header( }; tracing::info!("Received data stream type: {:?}", data_stream_type); - let result = match data_stream_type { - DataStreamType::StreamHeaderTrack => { - Header::depacketize(&mut read_cur).map(StreamHeader::Track) - } - DataStreamType::StreamHeaderSubgroup => { - subgroup_stream::Header::depacketize(&mut read_cur).map(StreamHeader::Subgroup) - } - unknown => { - return StreamHeaderProcessResult::Failure( - TerminationErrorCode::ProtocolViolation, - format!("Unknown message type: {:?}", unknown), - ); - } - }; + if data_stream_type != DataStreamType::SubgroupHeader { + return SubgroupStreamHeaderProcessResult::Failure( + TerminationErrorCode::ProtocolViolation, + format!("Unknown message type: {:?}", data_stream_type), + ); + } - match result { + match subgroup_stream::Header::depacketize(&mut read_cur) { Ok(stream_header) => { buf.advance(read_cur.position() as usize); - StreamHeaderProcessResult::Success(stream_header) + SubgroupStreamHeaderProcessResult::Success(stream_header) } Err(err) => { tracing::warn!("{:#?}", err); // Reset the cursor position because data for an object has not yet arrived read_cur.set_position(0); - StreamHeaderProcessResult::Continue + SubgroupStreamHeaderProcessResult::Continue } } } @@ -111,8 +97,8 @@ pub async fn try_read_header( mod tests { mod success { use crate::modules::{ - message_handlers::stream_header::{ - try_read_header, StreamHeader, StreamHeaderProcessResult, + message_handlers::subgroup_stream_header::{ + try_read_header, SubgroupStreamHeaderProcessResult, }, moqt_client::{MOQTClient, MOQTClientStatus}, server_processes::senders, @@ -120,47 +106,15 @@ mod tests { use bytes::BytesMut; use moqt_core::{ data_stream_type::DataStreamType, - messages::data_streams::{subgroup_stream, track_stream::Header, DataStreams}, + messages::data_streams::{subgroup_stream, DataStreams}, variable_integer::write_variable_integer, }; use std::{io::Cursor, sync::Arc}; use tokio::sync::Mutex; - #[tokio::test] - async fn track_stream_header_success() { - let data_stream_type = DataStreamType::StreamHeaderTrack; - let bytes_array = [ - 0, // Subscribe ID (i) - 1, // Track Alias (i) - 2, // Subscriber Priority (8) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len() + 8); - buf.extend(write_variable_integer(data_stream_type as u64)); - buf.extend_from_slice(&bytes_array); - - let senders_mock = senders::test_helper_fn::create_senders_mock(); - let upstream_session_id = 0; - - let mut client = MOQTClient::new(upstream_session_id, senders_mock); - client.update_status(MOQTClientStatus::SetUp); - let client = Arc::new(Mutex::new(client)); - - let result = try_read_header(&mut buf, client).await; - - let mut buf_without_type = BytesMut::with_capacity(bytes_array.len()); - buf_without_type.extend_from_slice(&bytes_array); - let mut read_cur = Cursor::new(&buf_without_type[..]); - let header = Header::depacketize(&mut read_cur).unwrap(); - - assert_eq!( - result, - StreamHeaderProcessResult::Success(StreamHeader::Track(header)) - ); - } - #[tokio::test] async fn subgroup_stream_header_success() { - let data_stream_type = DataStreamType::StreamHeaderSubgroup; + let data_stream_type = DataStreamType::SubgroupHeader; let bytes_array = [ 0, // Subscribe ID (i) 1, // Track Alias (i) @@ -186,38 +140,12 @@ mod tests { let mut read_cur = Cursor::new(&buf_without_type[..]); let header = subgroup_stream::Header::depacketize(&mut read_cur).unwrap(); - assert_eq!( - result, - StreamHeaderProcessResult::Success(StreamHeader::Subgroup(header)) - ); - } - - #[tokio::test] - async fn track_stream_header_continue_incomplete_message() { - let data_stream_type = DataStreamType::StreamHeaderTrack; - let bytes_array = [ - 0, // Group ID (i) - 1, // Object ID (i) - ]; - let mut buf = BytesMut::with_capacity(bytes_array.len() + 8); - buf.extend(write_variable_integer(data_stream_type as u64)); - buf.extend_from_slice(&bytes_array); - - let senders_mock = senders::test_helper_fn::create_senders_mock(); - let upstream_session_id = 0; - - let mut client = MOQTClient::new(upstream_session_id, senders_mock); - client.update_status(MOQTClientStatus::SetUp); - let client = Arc::new(Mutex::new(client)); - - let result = try_read_header(&mut buf, client).await; - - assert_eq!(result, StreamHeaderProcessResult::Continue); + assert_eq!(result, SubgroupStreamHeaderProcessResult::Success(header)); } #[tokio::test] async fn subgroup_stream_header_continue_incomplete_message() { - let data_stream_type = DataStreamType::StreamHeaderSubgroup; + let data_stream_type = DataStreamType::SubgroupHeader; let bytes_array = [ 0, // Object ID (i) ]; @@ -234,7 +162,7 @@ mod tests { let result = try_read_header(&mut buf, client).await; - assert_eq!(result, StreamHeaderProcessResult::Continue); + assert_eq!(result, SubgroupStreamHeaderProcessResult::Continue); } } } diff --git a/moqt-server/src/modules/message_handlers/subgroup_stream_object.rs b/moqt-server/src/modules/message_handlers/subgroup_stream_object.rs new file mode 100644 index 00000000..739ec853 --- /dev/null +++ b/moqt-server/src/modules/message_handlers/subgroup_stream_object.rs @@ -0,0 +1,96 @@ +use bytes::{Buf, BytesMut}; +use moqt_core::messages::data_streams::{subgroup_stream, DataStreams}; +use std::io::Cursor; + +#[derive(Debug, PartialEq)] +pub enum SubgroupStreamObjectProcessResult { + Success(subgroup_stream::Object), + Continue, +} + +pub async fn try_read_object(buf: &mut BytesMut) -> SubgroupStreamObjectProcessResult { + let payload_length = buf.len(); + tracing::trace!("stream_object_handler! {}", payload_length); + + // Check if the data is exist + if payload_length == 0 { + return SubgroupStreamObjectProcessResult::Continue; + } + + let mut read_cur = Cursor::new(&buf[..]); + + match subgroup_stream::Object::depacketize(&mut read_cur) { + Ok(stream_object) => { + buf.advance(read_cur.position() as usize); + SubgroupStreamObjectProcessResult::Success(stream_object) + } + Err(_err) => { + // TODO: `buffer does not have enough length` is not error. we want to change it to `Continue` + // tracing::info!("{:#?}", err); + // Reset the cursor position because data for an object has not yet arrived + read_cur.set_position(0); + SubgroupStreamObjectProcessResult::Continue + } + } +} + +#[cfg(test)] +mod tests { + mod success { + use crate::modules::message_handlers::subgroup_stream_object::{ + try_read_object, SubgroupStreamObjectProcessResult, + }; + use bytes::BytesMut; + use moqt_core::messages::data_streams::{subgroup_stream, DataStreams}; + use std::io::Cursor; + + #[tokio::test] + async fn stream_object_subgroup_success() { + let bytes_array = [ + 0, // Object ID (i) + 0, // Extension Header Length (i) + 3, // Object Payload Length (i) + 0, 1, 2, // Object Payload (..) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + let buf_clone = buf.clone(); + + let result = try_read_object(&mut buf).await; + + let mut read_cur = Cursor::new(&buf_clone[..]); + let object = subgroup_stream::Object::depacketize(&mut read_cur).unwrap(); + + assert_eq!(result, SubgroupStreamObjectProcessResult::Success(object)); + } + + #[tokio::test] + async fn stream_object_subgroup_continue_insufficient_payload() { + let bytes_array = [ + 0, // Object ID (i) + 0, // Extension Headers Length (i) + 50, // Object Payload Length (i) + 0, 1, 2, // Object Payload (..) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + + let result = try_read_object(&mut buf).await; + + assert_eq!(result, SubgroupStreamObjectProcessResult::Continue); + } + + #[tokio::test] + async fn stream_object_subgroup_continue_incomplete_message() { + let bytes_array = [ + 0, // Object ID (i) + ]; + let mut buf = BytesMut::with_capacity(bytes_array.len()); + buf.extend_from_slice(&bytes_array); + + let result = try_read_object(&mut buf).await; + + assert_eq!(result, SubgroupStreamObjectProcessResult::Continue); + } + } +} diff --git a/moqt-server/src/modules/moqt_client.rs b/moqt-server/src/modules/moqt_client.rs index 510540f2..6f0f8472 100644 --- a/moqt-server/src/modules/moqt_client.rs +++ b/moqt-server/src/modules/moqt_client.rs @@ -1,6 +1,4 @@ use super::server_processes::senders::Senders; -use anyhow::{bail, Ok, Result}; -use moqt_core::messages::control_messages::setup_parameters::RoleCase; use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -13,7 +11,6 @@ pub enum MOQTClientStatus { pub struct MOQTClient { id: usize, status: MOQTClientStatus, - role: Option, senders: Arc, } @@ -23,7 +20,6 @@ impl MOQTClient { MOQTClient { id, status: MOQTClientStatus::Connected, - role: None, senders, } } @@ -33,21 +29,10 @@ impl MOQTClient { pub fn status(&self) -> MOQTClientStatus { self.status } - pub fn role(&self) -> Option { - self.role - } pub fn update_status(&mut self, new_status: MOQTClientStatus) { self.status = new_status; } - pub fn set_role(&mut self, new_role: RoleCase) -> Result<()> { - if self.role.is_some() { - bail!("Client's role is already set."); - } - self.role = Some(new_role); - - Ok(()) - } pub fn senders(&self) -> Arc { self.senders.clone() diff --git a/moqt-server/src/modules/object_cache_storage/cache.rs b/moqt-server/src/modules/object_cache_storage/cache.rs index 64e4b9fe..f74f5ee7 100644 --- a/moqt-server/src/modules/object_cache_storage/cache.rs +++ b/moqt-server/src/modules/object_cache_storage/cache.rs @@ -1,10 +1,8 @@ pub(crate) mod datagram; pub(crate) mod subgroup_stream; -pub(crate) mod track_stream; use datagram::DatagramCache; use subgroup_stream::SubgroupStreamsCache; -use track_stream::TrackStreamCache; pub(crate) type CacheId = usize; type GroupId = u64; @@ -37,6 +35,5 @@ impl CacheKey { #[derive(Clone)] pub(crate) enum Cache { Datagram(DatagramCache), - TrackStream(TrackStreamCache), SubgroupStream(SubgroupStreamsCache), } diff --git a/moqt-server/src/modules/object_cache_storage/cache/datagram.rs b/moqt-server/src/modules/object_cache_storage/cache/datagram.rs index 036cf7bb..65c2573b 100644 --- a/moqt-server/src/modules/object_cache_storage/cache/datagram.rs +++ b/moqt-server/src/modules/object_cache_storage/cache/datagram.rs @@ -1,11 +1,11 @@ use super::CacheId; -use moqt_core::messages::data_streams::datagram; +use moqt_core::messages::data_streams::DatagramObject; use std::time::Duration; use ttl_cache::TtlCache; #[derive(Clone)] pub(crate) struct DatagramCache { - objects: TtlCache, + objects: TtlCache, next_cache_id: CacheId, } @@ -19,19 +19,29 @@ impl DatagramCache { } } - pub(crate) fn insert_object(&mut self, object: datagram::Object, duration: u64) { + pub(crate) fn insert_object(&mut self, object: DatagramObject, duration: u64) { let ttl = Duration::from_millis(duration); self.objects.insert(self.next_cache_id, object, ttl); self.next_cache_id += 1; } - pub(crate) fn get_absolute_object_with_cache_id( + pub(crate) fn get_object( &mut self, group_id: u64, object_id: u64, - ) -> Option<(CacheId, datagram::Object)> { + ) -> Option<(CacheId, DatagramObject)> { self.objects.iter().find_map(|(k, v)| { - if v.group_id() == group_id && v.object_id() == object_id { + let g_id = match v { + DatagramObject::ObjectDatagram(obj) => obj.group_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.group_id(), + }; + + let o_id = match v { + DatagramObject::ObjectDatagram(obj) => obj.object_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.object_id(), + }; + + if g_id == group_id && o_id == object_id { Some((*k, v.clone())) } else { None @@ -39,10 +49,10 @@ impl DatagramCache { }) } - pub(crate) fn get_next_object_with_cache_id( + pub(crate) fn get_next_object( &mut self, cache_id: CacheId, - ) -> Option<(CacheId, datagram::Object)> { + ) -> Option<(CacheId, DatagramObject)> { let next_cache_id = cache_id + 1; self.objects.iter().find_map(|(k, v)| { if *k == next_cache_id { @@ -53,52 +63,74 @@ impl DatagramCache { }) } - pub(crate) fn get_latest_group_with_cache_id(&mut self) -> Option<(CacheId, datagram::Object)> { + pub(crate) fn get_latest_group(&mut self) -> Option<(CacheId, DatagramObject)> { let latest_group_id = self .objects .iter() .last() - .map(|(_, v)| v.group_id()) + .map(|(_, v)| match v { + DatagramObject::ObjectDatagram(obj) => obj.group_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.group_id(), + }) .unwrap(); let latest_group = self.objects.iter().filter_map(|(k, v)| { - if v.group_id() == latest_group_id { + let g_id = match v { + DatagramObject::ObjectDatagram(obj) => obj.group_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.group_id(), + }; + if g_id == latest_group_id { Some((*k, v.clone())) } else { None } }); - latest_group.min_by_key(|(k, v)| (v.object_id(), *k)) + latest_group.min_by_key(|(k, v)| { + let o_id = match v { + DatagramObject::ObjectDatagram(obj) => obj.object_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.object_id(), + }; + (o_id, *k) + }) } - pub(crate) fn get_latest_object_with_cache_id( - &mut self, - ) -> Option<(CacheId, datagram::Object)> { + pub(crate) fn get_latest_object(&mut self) -> Option<(CacheId, DatagramObject)> { self.objects.iter().last().map(|(k, v)| (*k, v.clone())) } - pub(crate) fn get_largest_group_id(&mut self) -> u64 { + pub(crate) fn get_largest_group_id(&mut self) -> Option { self.objects .iter() - .map(|(_, v)| v.group_id()) + .map(|(_, v)| match v { + DatagramObject::ObjectDatagram(obj) => obj.group_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.group_id(), + }) .max() - .unwrap() } - pub(crate) fn get_largest_object_id(&mut self) -> u64 { - let largest_group_id = self.get_largest_group_id(); + pub(crate) fn get_largest_object_id(&mut self) -> Option { + let largest_group_id = self.get_largest_group_id()?; self.objects .iter() .filter_map(|(_, v)| { - if v.group_id() == largest_group_id { - Some(v.object_id()) + let g_id = match v { + DatagramObject::ObjectDatagram(obj) => obj.group_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.group_id(), + }; + + let o_id = match v { + DatagramObject::ObjectDatagram(obj) => obj.object_id(), + DatagramObject::ObjectDatagramStatus(obj) => obj.object_id(), + }; + + if g_id == largest_group_id { + Some(o_id) } else { None } }) .max() - .unwrap() } } diff --git a/moqt-server/src/modules/object_cache_storage/cache/subgroup_stream.rs b/moqt-server/src/modules/object_cache_storage/cache/subgroup_stream.rs index 38acdd7a..adc1af12 100644 --- a/moqt-server/src/modules/object_cache_storage/cache/subgroup_stream.rs +++ b/moqt-server/src/modules/object_cache_storage/cache/subgroup_stream.rs @@ -15,7 +15,7 @@ impl SubgroupStreamsCache { Self { streams } } - pub(crate) fn add_subgroup_stream( + pub(crate) fn set_subgroup_stream( &mut self, group_id: u64, subgroup_id: u64, @@ -48,7 +48,7 @@ impl SubgroupStreamsCache { .unwrap() } - pub(crate) fn get_absolute_object_with_cache_id( + pub(crate) fn get_object( &mut self, group_id: u64, subgroup_id: u64, @@ -56,10 +56,10 @@ impl SubgroupStreamsCache { ) -> Option<(CacheId, subgroup_stream::Object)> { let subgroup_stream_id = (group_id, subgroup_id); let subgroup_stream_cache = self.streams.get_mut(&subgroup_stream_id).unwrap(); - subgroup_stream_cache.get_absolute_object_with_cache_id(object_id) + subgroup_stream_cache.get_object(object_id) } - pub(crate) fn get_next_object_with_cache_id( + pub(crate) fn get_next_object( &mut self, group_id: u64, subgroup_id: u64, @@ -67,43 +67,53 @@ impl SubgroupStreamsCache { ) -> Option<(CacheId, subgroup_stream::Object)> { let subgroup_stream_id = (group_id, subgroup_id); let subgroup_stream_cache = self.streams.get_mut(&subgroup_stream_id).unwrap(); - subgroup_stream_cache.get_next_object_with_cache_id(cache_id) + subgroup_stream_cache.get_next_object(cache_id) } - pub(crate) fn get_first_object_with_cache_id( + pub(crate) fn get_first_object( &mut self, group_id: u64, subgroup_id: u64, ) -> Option<(CacheId, subgroup_stream::Object)> { let subgroup_stream_id = (group_id, subgroup_id); let subgroup_stream_cache = self.streams.get_mut(&subgroup_stream_id).unwrap(); - subgroup_stream_cache.get_first_object_with_cache_id() + subgroup_stream_cache.get_first_object() } - pub(crate) fn get_largest_group_id(&mut self) -> u64 { - self.streams.iter().map(|((gid, _), _)| *gid).max().unwrap() + pub(crate) fn get_latest_object( + &mut self, + group_id: u64, + subgroup_id: u64, + ) -> Option<(CacheId, subgroup_stream::Object)> { + let subgroup_stream_id = (group_id, subgroup_id); + let subgroup_stream_cache = self.streams.get_mut(&subgroup_stream_id).unwrap(); + subgroup_stream_cache.get_latest_object() } - pub(crate) fn get_largest_object_id(&mut self) -> u64 { - let largest_group_id = self.get_largest_group_id(); - let largest_subgroup_id = self - .streams - .iter() - .filter_map(|((gid, sgid), _)| { - if *gid == largest_group_id { - Some(*sgid) - } else { - None - } - }) - .max() - .unwrap(); - let subgroup_stream_id = (largest_group_id, largest_subgroup_id); + pub(crate) fn get_largest_group_id(&mut self) -> Option { + self.streams.iter().map(|((gid, _), _)| *gid).max() + } - self.streams - .get_mut(&subgroup_stream_id) - .unwrap() - .get_largest_object_id() + pub(crate) fn get_largest_object_id(&mut self) -> Option { + let largest_group_id = self.get_largest_group_id()?; + + let subgroup_ids = self.get_all_subgroup_ids(largest_group_id); + + let mut largest_object_id = None; + for subgroup_id in subgroup_ids.iter().rev() { + let subgroup_stream_id = (largest_group_id, *subgroup_id); + let object_id = self + .streams + .get_mut(&subgroup_stream_id) + .unwrap() + .get_largest_object_id(); + + if largest_object_id.is_none() || object_id > largest_object_id.unwrap() { + largest_object_id = Some(object_id); + } + } + + largest_object_id } pub(crate) fn get_all_subgroup_ids(&mut self, group_id: u64) -> Vec { @@ -154,10 +164,7 @@ impl SubgroupStreamCache { self.header.clone() } - fn get_absolute_object_with_cache_id( - &mut self, - object_id: u64, - ) -> Option<(CacheId, subgroup_stream::Object)> { + fn get_object(&mut self, object_id: u64) -> Option<(CacheId, subgroup_stream::Object)> { self.objects.iter().find_map(|(k, v)| { if v.object_id() == object_id { Some((*k, v.clone())) @@ -167,10 +174,7 @@ impl SubgroupStreamCache { }) } - fn get_next_object_with_cache_id( - &mut self, - cache_id: CacheId, - ) -> Option<(CacheId, subgroup_stream::Object)> { + fn get_next_object(&mut self, cache_id: CacheId) -> Option<(CacheId, subgroup_stream::Object)> { let next_cache_id = cache_id + 1; self.objects.iter().find_map(|(k, v)| { if *k == next_cache_id { @@ -181,10 +185,14 @@ impl SubgroupStreamCache { }) } - fn get_first_object_with_cache_id(&mut self) -> Option<(CacheId, subgroup_stream::Object)> { + fn get_first_object(&mut self) -> Option<(CacheId, subgroup_stream::Object)> { self.objects.iter().next().map(|(k, v)| (*k, v.clone())) } + fn get_latest_object(&mut self) -> Option<(CacheId, subgroup_stream::Object)> { + self.objects.iter().last().map(|(k, v)| (*k, v.clone())) + } + fn get_largest_object_id(&mut self) -> u64 { self.objects .iter() diff --git a/moqt-server/src/modules/object_cache_storage/cache/track_stream.rs b/moqt-server/src/modules/object_cache_storage/cache/track_stream.rs deleted file mode 100644 index 99fe9861..00000000 --- a/moqt-server/src/modules/object_cache_storage/cache/track_stream.rs +++ /dev/null @@ -1,112 +0,0 @@ -use super::CacheId; -use moqt_core::messages::data_streams::track_stream; -use std::time::Duration; -use ttl_cache::TtlCache; - -#[derive(Clone)] -pub(crate) struct TrackStreamCache { - header: track_stream::Header, - objects: TtlCache, - next_cache_id: CacheId, -} - -impl TrackStreamCache { - pub(crate) fn new(header: track_stream::Header, max_store_size: usize) -> Self { - let objects = TtlCache::new(max_store_size); - - Self { - header, - objects, - next_cache_id: 0, - } - } - - pub(crate) fn insert_object(&mut self, object: track_stream::Object, duration: u64) { - let ttl = Duration::from_millis(duration); - self.objects.insert(self.next_cache_id, object, ttl); - self.next_cache_id += 1; - } - - pub(crate) fn get_header(&self) -> track_stream::Header { - self.header.clone() - } - - pub(crate) fn get_absolute_object_with_cache_id( - &mut self, - group_id: u64, - object_id: u64, - ) -> Option<(CacheId, track_stream::Object)> { - self.objects.iter().find_map(|(k, v)| { - if v.group_id() == group_id && v.object_id() == object_id { - Some((*k, v.clone())) - } else { - None - } - }) - } - - pub(crate) fn get_next_object_with_cache_id( - &mut self, - cache_id: CacheId, - ) -> Option<(CacheId, track_stream::Object)> { - let next_cache_id = cache_id + 1; - self.objects.iter().find_map(|(k, v)| { - if *k == next_cache_id { - Some((*k, v.clone())) - } else { - None - } - }) - } - - pub(crate) fn get_latest_group_with_cache_id( - &mut self, - ) -> Option<(CacheId, track_stream::Object)> { - let latest_group_id = self - .objects - .iter() - .last() - .map(|(_, v)| v.group_id()) - .unwrap(); - - let latest_group = self.objects.iter().filter_map(|(k, v)| { - if v.group_id() == latest_group_id { - Some((*k, v.clone())) - } else { - None - } - }); - - latest_group.min_by_key(|(k, v)| (v.object_id(), *k)) - } - - pub(crate) fn get_latest_object_with_cache_id( - &mut self, - ) -> Option<(CacheId, track_stream::Object)> { - self.objects.iter().last().map(|(k, v)| (*k, v.clone())) - } - - pub(crate) fn get_largest_group_id(&mut self) -> u64 { - self.objects - .iter() - .map(|(_, v)| v.group_id()) - .max() - .unwrap() - } - - pub(crate) fn get_largest_object_id(&mut self) -> u64 { - let largest_group_id = self.get_largest_group_id(); - - self.objects - .iter() - .filter_map(|(_, v)| { - if v.group_id() == largest_group_id { - Some(v.object_id()) - } else { - None - } - }) - .max() - .unwrap() - } -} diff --git a/moqt-server/src/modules/object_cache_storage/commands.rs b/moqt-server/src/modules/object_cache_storage/commands.rs index 38c0d506..dd000560 100644 --- a/moqt-server/src/modules/object_cache_storage/commands.rs +++ b/moqt-server/src/modules/object_cache_storage/commands.rs @@ -1,6 +1,6 @@ use super::cache::{CacheId, CacheKey, SubgroupId}; use anyhow::Result; -use moqt_core::messages::data_streams::{datagram, subgroup_stream, track_stream}; +use moqt_core::messages::data_streams::{subgroup_stream, DatagramObject}; use tokio::sync::oneshot; #[derive(Debug)] @@ -9,11 +9,6 @@ pub(crate) enum ObjectCacheStorageCommand { cache_key: CacheKey, resp: oneshot::Sender>, }, - CreateTrackStreamCache { - cache_key: CacheKey, - header: track_stream::Header, - resp: oneshot::Sender>, - }, CreateSubgroupStreamCache { cache_key: CacheKey, group_id: u64, @@ -21,14 +16,10 @@ pub(crate) enum ObjectCacheStorageCommand { header: subgroup_stream::Header, resp: oneshot::Sender>, }, - ExistDatagramCache { + HasDatagramCache { cache_key: CacheKey, resp: oneshot::Sender>, }, - GetTrackStreamHeader { - cache_key: CacheKey, - resp: oneshot::Sender>, - }, GetSubgroupStreamHeader { cache_key: CacheKey, group_id: u64, @@ -37,13 +28,7 @@ pub(crate) enum ObjectCacheStorageCommand { }, SetDatagramObject { cache_key: CacheKey, - datagram_object: datagram::Object, - duration: u64, - resp: oneshot::Sender>, - }, - SetTrackStreamObject { - cache_key: CacheKey, - track_stream_object: track_stream::Object, + datagram_object: DatagramObject, duration: u64, resp: oneshot::Sender>, }, @@ -55,19 +40,13 @@ pub(crate) enum ObjectCacheStorageCommand { duration: u64, resp: oneshot::Sender>, }, - GetAbsoluteDatagramObject { - cache_key: CacheKey, - group_id: u64, - object_id: u64, - resp: oneshot::Sender>>, - }, - GetAbsoluteTrackStreamObject { + GetDatagramObject { cache_key: CacheKey, group_id: u64, object_id: u64, - resp: oneshot::Sender>>, + resp: oneshot::Sender>>, }, - GetAbsoluteSubgroupStreamObject { + GetSubgroupStreamObject { cache_key: CacheKey, group_id: u64, subgroup_id: u64, @@ -77,12 +56,7 @@ pub(crate) enum ObjectCacheStorageCommand { GetNextDatagramObject { cache_key: CacheKey, cache_id: CacheId, - resp: oneshot::Sender>>, - }, - GetNextTrackStreamObject { - cache_key: CacheKey, - cache_id: CacheId, - resp: oneshot::Sender>>, + resp: oneshot::Sender>>, }, GetNextSubgroupStreamObject { cache_key: CacheKey, @@ -93,19 +67,11 @@ pub(crate) enum ObjectCacheStorageCommand { }, GetLatestDatagramObject { cache_key: CacheKey, - resp: oneshot::Sender>>, - }, - GetLatestTrackStreamObject { - cache_key: CacheKey, - resp: oneshot::Sender>>, + resp: oneshot::Sender>>, }, GetLatestDatagramGroup { cache_key: CacheKey, - resp: oneshot::Sender>>, - }, - GetLatestTrackStreamGroup { - cache_key: CacheKey, - resp: oneshot::Sender>>, + resp: oneshot::Sender>>, }, // Since current Forwarder is generated for each Group, // LatestGroup is never used for SubgroupCache. @@ -116,6 +82,14 @@ pub(crate) enum ObjectCacheStorageCommand { subgroup_id: u64, resp: oneshot::Sender>>, }, + // TODO: Remove LatestGroup since it is not exist in the draft-10 + #[allow(dead_code)] + GetLatestSubgroupStreamObject { + cache_key: CacheKey, + group_id: u64, + subgroup_id: u64, + resp: oneshot::Sender>>, + }, GetAllSubgroupIds { cache_key: CacheKey, group_id: u64, @@ -123,11 +97,11 @@ pub(crate) enum ObjectCacheStorageCommand { }, GetLargestGroupId { cache_key: CacheKey, - resp: oneshot::Sender>, + resp: oneshot::Sender>>, }, GetLargestObjectId { cache_key: CacheKey, - resp: oneshot::Sender>, + resp: oneshot::Sender>>, }, DeleteClient { session_id: usize, diff --git a/moqt-server/src/modules/object_cache_storage/storage.rs b/moqt-server/src/modules/object_cache_storage/storage.rs index 4b78491b..5b5639c5 100644 --- a/moqt-server/src/modules/object_cache_storage/storage.rs +++ b/moqt-server/src/modules/object_cache_storage/storage.rs @@ -1,7 +1,6 @@ use super::commands::ObjectCacheStorageCommand; use crate::modules::object_cache_storage::cache::{ - datagram::DatagramCache, subgroup_stream::SubgroupStreamsCache, track_stream::TrackStreamCache, - Cache, CacheKey, + datagram::DatagramCache, subgroup_stream::SubgroupStreamsCache, Cache, CacheKey, }; use std::collections::HashMap; use tokio::sync::mpsc; @@ -26,19 +25,6 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { - let track_stream_cache = TrackStreamCache::new(header, max_cache_size); - let cache = Cache::TrackStream(track_stream_cache); - - // Insert the TrackStreamCache into the ObjectCacheStorage - storage.insert(cache_key.clone(), cache); - - resp.send(Ok(())).unwrap(); - } ObjectCacheStorageCommand::CreateSubgroupStreamCache { cache_key, group_id, @@ -57,7 +43,7 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { + ObjectCacheStorageCommand::HasDatagramCache { cache_key, resp } => { let cache = storage.get(&cache_key); match cache { Some(Cache::Datagram(_)) => { @@ -77,20 +63,6 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { - let cache = storage.get(&cache_key); - let track_stream_cache = match cache { - Some(Cache::TrackStream(track_stream_cache)) => track_stream_cache, - _ => { - resp.send(Err(anyhow::anyhow!("track stream cache not found"))) - .unwrap(); - continue; - } - }; - - let header = track_stream_cache.get_header(); - resp.send(Ok(header)).unwrap(); - } ObjectCacheStorageCommand::GetSubgroupStreamHeader { cache_key, group_id, @@ -129,25 +101,6 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { - let cache = storage.get_mut(&cache_key); - let track_stream_cache = match cache { - Some(Cache::TrackStream(track_stream_cache)) => track_stream_cache, - _ => { - resp.send(Err(anyhow::anyhow!("track stream cache not found"))) - .unwrap(); - continue; - } - }; - - track_stream_cache.insert_object(track_stream_object, duration); - resp.send(Ok(())).unwrap(); - } ObjectCacheStorageCommand::SetSubgroupStreamObject { cache_key, group_id, @@ -174,7 +127,7 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { - let cache = storage.get_mut(&cache_key); - let track_stream_cache = match cache { - Some(Cache::TrackStream(track_stream_cache)) => track_stream_cache, - _ => { - resp.send(Err(anyhow::anyhow!("track stream cache not found"))) - .unwrap(); - continue; - } - }; - - let object_with_cache_id = - track_stream_cache.get_absolute_object_with_cache_id(group_id, object_id); + let object_with_cache_id = datagram_cache.get_object(group_id, object_id); resp.send(Ok(object_with_cache_id)).unwrap(); } - ObjectCacheStorageCommand::GetAbsoluteSubgroupStreamObject { + ObjectCacheStorageCommand::GetSubgroupStreamObject { cache_key, group_id, subgroup_id, @@ -231,8 +163,8 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { - let cache = storage.get_mut(&cache_key); - let track_stream_cache = match cache { - Some(Cache::TrackStream(track_stream_cache)) => track_stream_cache, - _ => { - resp.send(Err(anyhow::anyhow!("track stream cache not found"))) - .unwrap(); - continue; - } - }; - - let object_with_cache_id = - track_stream_cache.get_next_object_with_cache_id(cache_id); + let object_with_cache_id = datagram_cache.get_next_object(cache_id); resp.send(Ok(object_with_cache_id)).unwrap(); } ObjectCacheStorageCommand::GetNextSubgroupStreamObject { @@ -289,11 +202,8 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { @@ -307,21 +217,21 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { + ObjectCacheStorageCommand::GetLatestDatagramObject { cache_key, resp } => { let cache = storage.get_mut(&cache_key); - let track_stream_cache = match cache { - Some(Cache::TrackStream(track_stream_cache)) => track_stream_cache, + let datagram_cache = match cache { + Some(Cache::Datagram(datagram_cache)) => datagram_cache, _ => { - resp.send(Err(anyhow::anyhow!("track stream cache not found"))) + resp.send(Err(anyhow::anyhow!("datagram cache not found"))) .unwrap(); continue; } }; - let object_with_cache_id = track_stream_cache.get_latest_group_with_cache_id(); + let object_with_cache_id = datagram_cache.get_latest_object(); resp.send(Ok(object_with_cache_id)).unwrap(); } ObjectCacheStorageCommand::GetFirstSubgroupStreamObject { @@ -341,35 +251,27 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { - let cache = storage.get_mut(&cache_key); - let datagram_cache = match cache { - Some(Cache::Datagram(datagram_cache)) => datagram_cache, - _ => { - resp.send(Err(anyhow::anyhow!("datagram cache not found"))) - .unwrap(); - continue; - } - }; - - let object_with_cache_id = datagram_cache.get_latest_object_with_cache_id(); - resp.send(Ok(object_with_cache_id)).unwrap(); - } - ObjectCacheStorageCommand::GetLatestTrackStreamObject { cache_key, resp } => { + ObjectCacheStorageCommand::GetLatestSubgroupStreamObject { + cache_key, + group_id, + subgroup_id, + resp, + } => { let cache = storage.get_mut(&cache_key); - let track_stream_cache = match cache { - Some(Cache::TrackStream(track_stream_cache)) => track_stream_cache, + let subgroup_streams_cache = match cache { + Some(Cache::SubgroupStream(subgroup_stream_cache)) => subgroup_stream_cache, _ => { - resp.send(Err(anyhow::anyhow!("track stream cache not found"))) + resp.send(Err(anyhow::anyhow!("subgroup stream cache not found"))) .unwrap(); continue; } }; - let object_with_cache_id = track_stream_cache.get_latest_object_with_cache_id(); + let object_with_cache_id = + subgroup_streams_cache.get_latest_object(group_id, subgroup_id); resp.send(Ok(object_with_cache_id)).unwrap(); } ObjectCacheStorageCommand::GetAllSubgroupIds { @@ -393,11 +295,8 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { let cache = storage.get_mut(&cache_key); if let Some(cache) = cache { - let largest_group_id: u64 = match cache { + let largest_group_id = match cache { Cache::Datagram(datagram_cache) => datagram_cache.get_largest_group_id(), - Cache::TrackStream(track_stream_cache) => { - track_stream_cache.get_largest_group_id() - } Cache::SubgroupStream(subgroup_stream_cache) => { subgroup_stream_cache.get_largest_group_id() } @@ -411,11 +310,8 @@ pub(crate) async fn object_cache_storage(rx: &mut mpsc::Receiver { let cache = storage.get_mut(&cache_key); if let Some(cache) = cache { - let largest_object_id: u64 = match cache { + let largest_object_id = match cache { Cache::Datagram(datagram_cache) => datagram_cache.get_largest_object_id(), - Cache::TrackStream(track_stream_cache) => { - track_stream_cache.get_largest_object_id() - } Cache::SubgroupStream(subgroup_stream_cache) => { subgroup_stream_cache.get_largest_object_id() } diff --git a/moqt-server/src/modules/object_cache_storage/wrapper.rs b/moqt-server/src/modules/object_cache_storage/wrapper.rs index f23ad79d..33ba543b 100644 --- a/moqt-server/src/modules/object_cache_storage/wrapper.rs +++ b/moqt-server/src/modules/object_cache_storage/wrapper.rs @@ -3,7 +3,7 @@ use super::{ commands::ObjectCacheStorageCommand, }; use anyhow::{bail, Result}; -use moqt_core::messages::data_streams::{datagram, subgroup_stream, track_stream}; +use moqt_core::messages::data_streams::{subgroup_stream, DatagramObject}; use tokio::sync::{mpsc, oneshot}; pub(crate) struct ObjectCacheStorageWrapper { @@ -33,29 +33,6 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn create_track_stream_cache( - &mut self, - cache_key: &CacheKey, - header: track_stream::Header, - ) -> Result<()> { - let (resp_tx, resp_rx) = oneshot::channel::>(); - - let cmd = ObjectCacheStorageCommand::CreateTrackStreamCache { - cache_key: cache_key.clone(), - header, - resp: resp_tx, - }; - - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - - match result { - Ok(_) => Ok(()), - Err(err) => bail!(err), - } - } - pub(crate) async fn create_subgroup_stream_cache( &mut self, cache_key: &CacheKey, @@ -86,7 +63,7 @@ impl ObjectCacheStorageWrapper { pub(crate) async fn exist_datagram_cache(&mut self, cache_key: &CacheKey) -> Result { let (resp_tx, resp_rx) = oneshot::channel::>(); - let cmd = ObjectCacheStorageCommand::ExistDatagramCache { + let cmd = ObjectCacheStorageCommand::HasDatagramCache { cache_key: cache_key.clone(), resp: resp_tx, }; @@ -101,27 +78,6 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_track_stream_header( - &mut self, - cache_key: &CacheKey, - ) -> Result { - let (resp_tx, resp_rx) = oneshot::channel::>(); - - let cmd = ObjectCacheStorageCommand::GetTrackStreamHeader { - cache_key: cache_key.clone(), - resp: resp_tx, - }; - - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - - match result { - Ok(header_cache) => Ok(header_cache), - Err(err) => bail!(err), - } - } - pub(crate) async fn get_subgroup_stream_header( &mut self, cache_key: &CacheKey, @@ -150,7 +106,7 @@ impl ObjectCacheStorageWrapper { pub(crate) async fn set_datagram_object( &mut self, cache_key: &CacheKey, - datagram_object: datagram::Object, + datagram_object: DatagramObject, duration: u64, ) -> Result<()> { let (resp_tx, resp_rx) = oneshot::channel::>(); @@ -172,31 +128,6 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn set_track_stream_object( - &mut self, - cache_key: &CacheKey, - track_stream_object: track_stream::Object, - duration: u64, - ) -> Result<()> { - let (resp_tx, resp_rx) = oneshot::channel::>(); - - let cmd = ObjectCacheStorageCommand::SetTrackStreamObject { - cache_key: cache_key.clone(), - track_stream_object, - duration, - resp: resp_tx, - }; - - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - - match result { - Ok(_) => Ok(()), - Err(err) => bail!(err), - } - } - pub(crate) async fn set_subgroup_stream_object( &mut self, cache_key: &CacheKey, @@ -231,36 +162,10 @@ impl ObjectCacheStorageWrapper { cache_key: &CacheKey, group_id: u64, object_id: u64, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); - - let cmd = ObjectCacheStorageCommand::GetAbsoluteDatagramObject { - cache_key: cache_key.clone(), - group_id, - object_id, - resp: resp_tx, - }; - - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - - match result { - Ok(object_cache) => Ok(object_cache), - Err(err) => bail!(err), - } - } - - pub(crate) async fn get_absolute_track_stream_object( - &mut self, - cache_key: &CacheKey, - group_id: u64, - object_id: u64, - ) -> Result> { - let (resp_tx, resp_rx) = - oneshot::channel::>>(); + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = ObjectCacheStorageCommand::GetAbsoluteTrackStreamObject { + let cmd = ObjectCacheStorageCommand::GetDatagramObject { cache_key: cache_key.clone(), group_id, object_id, @@ -287,7 +192,7 @@ impl ObjectCacheStorageWrapper { let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = ObjectCacheStorageCommand::GetAbsoluteSubgroupStreamObject { + let cmd = ObjectCacheStorageCommand::GetSubgroupStreamObject { cache_key: cache_key.clone(), group_id, subgroup_id, @@ -309,8 +214,8 @@ impl ObjectCacheStorageWrapper { &mut self, cache_key: &CacheKey, cache_id: usize, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); let cmd = ObjectCacheStorageCommand::GetNextDatagramObject { cache_key: cache_key.clone(), @@ -328,30 +233,6 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_next_track_stream_object( - &mut self, - cache_key: &CacheKey, - cache_id: usize, - ) -> Result> { - let (resp_tx, resp_rx) = - oneshot::channel::>>(); - - let cmd = ObjectCacheStorageCommand::GetNextTrackStreamObject { - cache_key: cache_key.clone(), - cache_id, - resp: resp_tx, - }; - - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - - match result { - Ok(object_cache) => Ok(object_cache), - Err(err) => bail!(err), - } - } - pub(crate) async fn get_next_subgroup_stream_object( &mut self, cache_key: &CacheKey, @@ -383,8 +264,8 @@ impl ObjectCacheStorageWrapper { pub(crate) async fn get_latest_datagram_object( &mut self, cache_key: &CacheKey, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); let cmd = ObjectCacheStorageCommand::GetLatestDatagramObject { cache_key: cache_key.clone(), @@ -401,33 +282,11 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_latest_track_stream_object( - &mut self, - cache_key: &CacheKey, - ) -> Result> { - let (resp_tx, resp_rx) = - oneshot::channel::>>(); - - let cmd = ObjectCacheStorageCommand::GetLatestTrackStreamObject { - cache_key: cache_key.clone(), - resp: resp_tx, - }; - - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - - match result { - Ok(object_cache) => Ok(object_cache), - Err(err) => bail!(err), - } - } - pub(crate) async fn get_latest_datagram_group( &mut self, cache_key: &CacheKey, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); let cmd = ObjectCacheStorageCommand::GetLatestDatagramGroup { cache_key: cache_key.clone(), @@ -444,15 +303,19 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_latest_track_stream_group( + pub(crate) async fn get_first_subgroup_stream_object( &mut self, cache_key: &CacheKey, - ) -> Result> { + group_id: u64, + subgroup_id: u64, + ) -> Result> { let (resp_tx, resp_rx) = - oneshot::channel::>>(); + oneshot::channel::>>(); - let cmd = ObjectCacheStorageCommand::GetLatestTrackStreamGroup { + let cmd = ObjectCacheStorageCommand::GetFirstSubgroupStreamObject { cache_key: cache_key.clone(), + group_id, + subgroup_id, resp: resp_tx, }; @@ -466,7 +329,8 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_first_subgroup_stream_object( + #[allow(dead_code)] + pub(crate) async fn get_latest_subgroup_stream_object( &mut self, cache_key: &CacheKey, group_id: u64, @@ -475,7 +339,7 @@ impl ObjectCacheStorageWrapper { let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = ObjectCacheStorageCommand::GetFirstSubgroupStreamObject { + let cmd = ObjectCacheStorageCommand::GetLatestSubgroupStreamObject { cache_key: cache_key.clone(), group_id, subgroup_id, @@ -515,8 +379,11 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_largest_group_id(&mut self, cache_key: &CacheKey) -> Result { - let (resp_tx, resp_rx) = oneshot::channel::>(); + pub(crate) async fn get_largest_group_id( + &mut self, + cache_key: &CacheKey, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); let cmd = ObjectCacheStorageCommand::GetLargestGroupId { cache_key: cache_key.clone(), @@ -533,8 +400,11 @@ impl ObjectCacheStorageWrapper { } } - pub(crate) async fn get_largest_object_id(&mut self, cache_key: &CacheKey) -> Result { - let (resp_tx, resp_rx) = oneshot::channel::>(); + pub(crate) async fn get_largest_object_id( + &mut self, + cache_key: &CacheKey, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); let cmd = ObjectCacheStorageCommand::GetLargestObjectId { cache_key: cache_key.clone(), @@ -546,7 +416,7 @@ impl ObjectCacheStorageWrapper { let result = resp_rx.await.unwrap(); match result { - Ok(group_id) => Ok(group_id), + Ok(object_id) => Ok(object_id), Err(err) => bail!(err), } } @@ -576,7 +446,9 @@ mod success { cache::CacheKey, commands::ObjectCacheStorageCommand, storage::object_cache_storage, wrapper::ObjectCacheStorageWrapper, }; - use moqt_core::messages::data_streams::{datagram, subgroup_stream, track_stream}; + use moqt_core::messages::data_streams::{ + datagram, datagram_status, object_status::ObjectStatus, subgroup_stream, DatagramObject, + }; use tokio::sync::mpsc; #[tokio::test] @@ -596,30 +468,6 @@ mod success { assert!(result.is_ok()); } - #[tokio::test] - async fn create_track_stream_cache() { - let session_id = 0; - let subscribe_id = 1; - let cache_key = CacheKey::new(session_id, subscribe_id); - let track_alias = 2; - let publisher_priority = 3; - - let track_stream_header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - // start object cache storage thread - let (cache_tx, mut cache_rx) = mpsc::channel::(1024); - tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); - - let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - - let result = object_cache_storage - .create_track_stream_cache(&cache_key, track_stream_header) - .await; - - assert!(result.is_ok()); - } - #[tokio::test] async fn create_subgroup_stream_cache() { let session_id = 0; @@ -630,14 +478,9 @@ mod success { let subgroup_id = 4; let publisher_priority = 5; - let subgroup_stream_header = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, - ) - .unwrap(); + let subgroup_stream_header = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -676,34 +519,6 @@ mod success { assert!(result.unwrap()); } - #[tokio::test] - async fn get_track_stream_header() { - let session_id = 0; - let subscribe_id = 1; - let track_alias = 2; - let publisher_priority = 3; - let cache_key = CacheKey::new(session_id, subscribe_id); - let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - // start object cache storage thread - let (cache_tx, mut cache_rx) = mpsc::channel::(1024); - tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); - - let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - - let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header.clone()) - .await; - - let result = object_cache_storage - .get_track_stream_header(&cache_key) - .await; - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), header); - } - #[tokio::test] async fn get_subgroup_stream_header() { let session_id = 0; @@ -713,14 +528,9 @@ mod success { let subgroup_id = 4; let publisher_priority = 5; let cache_key = CacheKey::new(session_id, subscribe_id); - let header = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, - ) - .unwrap(); + let header = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -749,19 +559,19 @@ mod success { let track_alias = 3; let group_id = 4; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let object_payload = vec![1, 2, 3, 4]; let duration = 1000; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -778,19 +588,27 @@ mod success { } #[tokio::test] - async fn set_track_stream_object() { + async fn set_datagram_object_status() { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); let object_id = 2; - let group_id = 3; - let publisher_priority = 4; - let object_status = None; - let object_payload = vec![1, 2, 3, 4]; - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload).unwrap(); - let header = track_stream::Header::new(subscribe_id, group_id, publisher_priority).unwrap(); + let track_alias = 3; + let group_id = 4; + let publisher_priority = 5; + let extension_headers = vec![]; + let object_status = ObjectStatus::EndOfGroup; let duration = 1000; + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers, + object_status, + ) + .unwrap(); + let datagram_object = DatagramObject::ObjectDatagramStatus(datagram_object); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -798,11 +616,9 @@ mod success { let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) - .await; + let _ = object_cache_storage.create_datagram_cache(&cache_key).await; let result = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) + .set_datagram_object(&cache_key, datagram_object, duration) .await; assert!(result.is_ok()); @@ -818,18 +634,19 @@ mod success { let group_id = 3; let subgroup_id = 4; let publisher_priority = 5; + let extension_headers = vec![]; let object_status = None; let object_payload = vec![1, 2, 3, 4]; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); - let header = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + object_payload, ) .unwrap(); + let header = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); let duration = 1000; // start object cache storage thread @@ -862,7 +679,7 @@ mod success { let track_alias = 3; let group_id = 4; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; // start object cache storage thread @@ -878,15 +695,15 @@ mod success { let object_id = i as u64; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers.clone(), object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); let _ = object_cache_storage .set_datagram_object(&cache_key, datagram_object, duration) @@ -897,15 +714,15 @@ mod success { let expected_cache_id = 5; let expected_object_payload = vec![5, 6, 7, 8]; let expected_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers, expected_object_payload, ) .unwrap(); + let expected_object = DatagramObject::ObjectDatagram(expected_object); let result = object_cache_storage .get_absolute_datagram_object(&cache_key, group_id, object_id) @@ -919,17 +736,20 @@ mod success { } #[tokio::test] - async fn get_absolute_track_stream_object() { + async fn get_absolute_subgroup_stream_object() { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; let group_id = 4; - let publisher_priority = 5; + let subgroup_id = 5; + let publisher_priority = 6; + let extension_headers = vec![]; let object_status = None; let duration = 1000; let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -937,31 +757,45 @@ mod success { let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) + .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, header) .await; for i in 0..10 { let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; let object_id = i as u64; - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers.clone(), + object_status, + object_payload, + ) + .unwrap(); let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) + .set_subgroup_stream_object( + &cache_key, + group_id, + subgroup_id, + subgroup_stream_object, + duration, + ) .await; } - let object_id = 7; - let expected_cache_id = 7; - let expected_object_payload = vec![7, 8, 9, 10]; - let expected_object = - track_stream::Object::new(group_id, object_id, object_status, expected_object_payload) - .unwrap(); + let object_id = 9; + let expected_cache_id = 9; + let expected_object_payload = vec![9, 10, 11, 12]; + let expected_object = subgroup_stream::Object::new( + object_id, + extension_headers, + object_status, + expected_object_payload, + ) + .unwrap(); let result = object_cache_storage - .get_absolute_track_stream_object(&cache_key, group_id, object_id) + .get_absolute_subgroup_stream_object(&cache_key, group_id, subgroup_id, object_id) .await; assert!(result.is_ok()); @@ -972,79 +806,14 @@ mod success { } #[tokio::test] - async fn get_absolute_subgroup_stream_object() { - let session_id = 0; - let subscribe_id = 1; - let cache_key = CacheKey::new(session_id, subscribe_id); - let track_alias = 3; - let group_id = 4; - let subgroup_id = 5; - let publisher_priority = 6; - let object_status = None; - let duration = 1000; - let header = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, - ) - .unwrap(); - - // start object cache storage thread - let (cache_tx, mut cache_rx) = mpsc::channel::(1024); - tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); - let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - - let _ = object_cache_storage - .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, header) - .await; - - for i in 0..10 { - let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; - let object_id = i as u64; - - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); - - let _ = object_cache_storage - .set_subgroup_stream_object( - &cache_key, - group_id, - subgroup_id, - subgroup_stream_object, - duration, - ) - .await; - } - - let object_id = 9; - let expected_cache_id = 9; - let expected_object_payload = vec![9, 10, 11, 12]; - let expected_object = - subgroup_stream::Object::new(object_id, object_status, expected_object_payload) - .unwrap(); - - let result = object_cache_storage - .get_absolute_subgroup_stream_object(&cache_key, group_id, subgroup_id, object_id) - .await; - - assert!(result.is_ok()); - - let (result_cache_id, result_object) = result.unwrap().unwrap(); - assert_eq!(result_cache_id, expected_cache_id); - assert_eq!(result_object, expected_object); - } - - #[tokio::test] - async fn get_next_datagram_object() { + async fn get_next_datagram_object() { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; let group_id = 4; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; // start object cache storage thread @@ -1059,15 +828,15 @@ mod success { let object_id = i as u64; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers.clone(), object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); let _ = object_cache_storage .set_datagram_object(&cache_key, datagram_object, duration) @@ -1079,15 +848,15 @@ mod success { let expected_cache_id = 3; let expected_object_payload = vec![3, 4, 5, 6]; let expected_object = datagram::Object::new( - subscribe_id, track_alias, group_id, expected_object_id, publisher_priority, - object_status, + extension_headers, expected_object_payload, ) .unwrap(); + let expected_object = DatagramObject::ObjectDatagram(expected_object); let result = object_cache_storage .get_next_datagram_object(&cache_key, cache_id) @@ -1101,54 +870,60 @@ mod success { } #[tokio::test] - async fn get_next_track_stream_object() { + async fn get_next_datagram_object_status() { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; let group_id = 4; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; - let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) - .await; + let _ = object_cache_storage.create_datagram_cache(&cache_key).await; for i in 0..10 { - let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; + let object_status = ObjectStatus::DoesNotExist; let object_id = i as u64; - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); + let datagram_object = datagram_status::Object::new( + track_alias, + group_id, + object_id, + publisher_priority, + extension_headers.clone(), + object_status, + ) + .unwrap(); + let datagram_object = DatagramObject::ObjectDatagramStatus(datagram_object); let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) + .set_datagram_object(&cache_key, datagram_object, duration) .await; } - let cache_id = 4; - let expected_object_id = 5; - let expected_cache_id = 5; - let expected_object_payload = vec![5, 6, 7, 8]; - let expected_object = track_stream::Object::new( + let cache_id = 2; + let expected_object_id = 3; + let expected_cache_id = 3; + let expected_object_status = ObjectStatus::DoesNotExist; + let expected_object = datagram_status::Object::new( + track_alias, group_id, expected_object_id, - object_status, - expected_object_payload, + publisher_priority, + extension_headers, + expected_object_status, ) .unwrap(); + let expected_object = DatagramObject::ObjectDatagramStatus(expected_object); let result = object_cache_storage - .get_next_track_stream_object(&cache_key, cache_id) + .get_next_datagram_object(&cache_key, cache_id) .await; assert!(result.is_ok()); @@ -1167,16 +942,12 @@ mod success { let group_id = 4; let subgroup_id = 5; let publisher_priority = 6; + let extension_headers = vec![]; let object_status = None; let duration = 1000; - let header = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, - ) - .unwrap(); + let header = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -1191,8 +962,13 @@ mod success { let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; let object_id = i as u64; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers.clone(), + object_status, + object_payload, + ) + .unwrap(); let _ = object_cache_storage .set_subgroup_stream_object( @@ -1211,6 +987,7 @@ mod success { let expected_object_payload = vec![1, 2, 3, 4]; let expected_object = subgroup_stream::Object::new( expected_object_id, + extension_headers, object_status, expected_object_payload, ) @@ -1235,7 +1012,7 @@ mod success { let track_alias = 3; let group_id = 4; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; // start object cache storage thread @@ -1250,15 +1027,15 @@ mod success { let object_id = i as u64; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers.clone(), object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); let _ = object_cache_storage .set_datagram_object(&cache_key, datagram_object, duration) @@ -1269,15 +1046,15 @@ mod success { let expected_cache_id = 5; let expected_object_payload = vec![5, 6, 7, 8]; let expected_object = datagram::Object::new( - subscribe_id, track_alias, group_id, expected_object_id, publisher_priority, - object_status, + extension_headers, expected_object_payload, ) .unwrap(); + let expected_object = DatagramObject::ObjectDatagram(expected_object); let result = object_cache_storage .get_latest_datagram_object(&cache_key) @@ -1290,63 +1067,6 @@ mod success { assert_eq!(result_object, expected_object); } - #[tokio::test] - async fn get_latest_track_stream_object() { - let session_id = 0; - let subscribe_id = 1; - let cache_key = CacheKey::new(session_id, subscribe_id); - let track_alias = 3; - let group_id = 4; - let publisher_priority = 5; - let object_status = None; - let duration = 1000; - let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - // start object cache storage thread - let (cache_tx, mut cache_rx) = mpsc::channel::(1024); - tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); - let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - - let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) - .await; - - for i in 0..13 { - let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; - let object_id = i as u64; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) - .await; - } - - let expected_object_id = 12; - let expected_cache_id = 12; - let expected_object_payload = vec![12, 13, 14, 15]; - let expected_object = track_stream::Object::new( - group_id, - expected_object_id, - object_status, - expected_object_payload, - ) - .unwrap(); - - let result = object_cache_storage - .get_latest_track_stream_object(&cache_key) - .await; - - assert!(result.is_ok()); - - let (result_cache_id, result_object) = result.unwrap().unwrap(); - assert_eq!(result_cache_id, expected_cache_id); - assert_eq!(result_object, expected_object); - } - #[tokio::test] async fn get_latest_group_ascending_datagram_object() { let session_id = 0; @@ -1354,7 +1074,7 @@ mod success { let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; // start object cache storage thread @@ -1378,15 +1098,15 @@ mod success { let object_id = i as u64; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers.clone(), object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); let _ = object_cache_storage .set_datagram_object(&cache_key, datagram_object, duration) @@ -1398,15 +1118,15 @@ mod success { let expected_group_id = 3; let expected_object_payload = vec![21, 22, 23, 24]; let expected_object = datagram::Object::new( - subscribe_id, track_alias, expected_group_id, expected_object_id, publisher_priority, - object_status, + extension_headers, expected_object_payload, ) .unwrap(); + let expected_object = DatagramObject::ObjectDatagram(expected_object); let expected_cache_id = group_size * expected_group_id as u8 + expected_object_id as u8; let result = object_cache_storage @@ -1427,7 +1147,7 @@ mod success { let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; // start object cache storage thread @@ -1451,15 +1171,15 @@ mod success { let object_id = i as u64; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers.clone(), object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); let _ = object_cache_storage .set_datagram_object(&cache_key, datagram_object, duration) @@ -1472,15 +1192,15 @@ mod success { let expected_cache_id = 49; let expected_object_payload = vec![14, 15, 16, 17]; let expected_object = datagram::Object::new( - subscribe_id, track_alias, expected_group_id, expected_object_id, publisher_priority, - object_status, + extension_headers.clone(), expected_object_payload, ) .unwrap(); + let expected_object = DatagramObject::ObjectDatagram(expected_object); let result = object_cache_storage .get_latest_datagram_group(&cache_key) @@ -1494,83 +1214,20 @@ mod success { } #[tokio::test] - async fn get_latest_group_ascending_track_stream() { - let session_id = 0; - let subscribe_id = 1; - let cache_key = CacheKey::new(session_id, subscribe_id); - let track_alias = 3; - let publisher_priority = 5; - let object_status = None; - let duration = 1000; - let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - // start object cache storage thread - let (cache_tx, mut cache_rx) = mpsc::channel::(1024); - tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); - let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - - let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) - .await; - - let group_size = 12; - for j in 0..8 { - let group_id = j as u64; - - for i in 0..group_size { - let object_payload: Vec = vec![ - j * group_size + i, - j * group_size + i + 1, - j * group_size + i + 2, - j * group_size + i + 3, - ]; - let object_id = i as u64; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) - .await; - } - } - - let expected_object_id = 0; - let expected_group_id = 7; - let expected_object_payload = vec![84, 85, 86, 87]; - let expected_object = track_stream::Object::new( - expected_group_id, - expected_object_id, - object_status, - expected_object_payload, - ) - .unwrap(); - let expected_cache_id = group_size * expected_group_id as u8 + expected_object_id as u8; - - let result = object_cache_storage - .get_latest_track_stream_group(&cache_key) - .await; - - assert!(result.is_ok()); - - let (result_cache_id, result_object) = result.unwrap().unwrap(); - assert_eq!(result_cache_id, expected_cache_id as usize); - assert_eq!(result_object, expected_object); - } - - #[tokio::test] - async fn get_latest_group_descending_track_stream() { + async fn get_first_subgroup_stream_object() { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; - let publisher_priority = 5; + let group_id = 4; + let subgroup_id = 5; + let publisher_priority = 6; + let extension_headers = vec![]; let object_status = None; let duration = 1000; let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -1578,46 +1235,45 @@ mod success { let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) + .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, header) .await; - let group_size = 12; - for j in (5..9).rev() { - let group_id = j as u64; - - for i in 0..group_size { - let object_payload: Vec = vec![ - j * group_size + i, - j * group_size + i + 1, - j * group_size + i + 2, - j * group_size + i + 3, - ]; - let object_id = i as u64; + for i in 0..20 { + let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; + let object_id = i as u64; - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers.clone(), + object_status, + object_payload, + ) + .unwrap(); - let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) - .await; - } + let _ = object_cache_storage + .set_subgroup_stream_object( + &cache_key, + group_id, + subgroup_id, + subgroup_stream_object, + duration, + ) + .await; } let expected_object_id = 0; - let expected_group_id = 5; - let expected_cache_id = 36; - let expected_object_payload = vec![60, 61, 62, 63]; - let expected_object = track_stream::Object::new( - expected_group_id, + let expected_cache_id = 0; + let expected_object_payload = vec![0, 1, 2, 3]; + let expected_object = subgroup_stream::Object::new( expected_object_id, + extension_headers, object_status, expected_object_payload, ) .unwrap(); let result = object_cache_storage - .get_latest_track_stream_group(&cache_key) + .get_first_subgroup_stream_object(&cache_key, group_id, subgroup_id) .await; assert!(result.is_ok()); @@ -1628,7 +1284,7 @@ mod success { } #[tokio::test] - async fn get_first_subgroup_stream_object() { + async fn get_latest_subgroup_stream_object() { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); @@ -1636,16 +1292,12 @@ mod success { let group_id = 4; let subgroup_id = 5; let publisher_priority = 6; + let extension_headers = vec![]; let object_status = None; let duration = 1000; - let header = subgroup_stream::Header::new( - subscribe_id, - track_alias, - group_id, - subgroup_id, - publisher_priority, - ) - .unwrap(); + let header = + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -1660,8 +1312,13 @@ mod success { let object_payload: Vec = vec![i, i + 1, i + 2, i + 3]; let object_id = i as u64; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers.clone(), + object_status, + object_payload, + ) + .unwrap(); let _ = object_cache_storage .set_subgroup_stream_object( @@ -1674,18 +1331,19 @@ mod success { .await; } - let expected_object_id = 0; - let expected_cache_id = 0; - let expected_object_payload = vec![0, 1, 2, 3]; + let expected_object_id = 19; + let expected_cache_id = 19; + let expected_object_payload = vec![19, 20, 21, 22]; let expected_object = subgroup_stream::Object::new( expected_object_id, + extension_headers, object_status, expected_object_payload, ) .unwrap(); let result = object_cache_storage - .get_first_subgroup_stream_object(&cache_key, group_id, subgroup_id) + .get_latest_subgroup_stream_object(&cache_key, group_id, subgroup_id) .await; assert!(result.is_ok()); @@ -1703,6 +1361,7 @@ mod success { let track_alias = 3; let group_id = 4; let publisher_priority = 6; + let extension_headers = vec![]; let object_status = None; let duration = 1000; @@ -1715,7 +1374,6 @@ mod success { let subgroup_id = i as u64; let header = subgroup_stream::Header::new( - subscribe_id, track_alias, group_id, subgroup_id, @@ -1727,8 +1385,13 @@ mod success { .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, header) .await; - let subgroup_stream_object = - subgroup_stream::Object::new(subgroup_id, object_status, vec![]).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + subgroup_id, + extension_headers.clone(), + object_status, + vec![], + ) + .unwrap(); let _ = object_cache_storage .set_subgroup_stream_object( @@ -1760,7 +1423,7 @@ mod success { let cache_key = CacheKey::new(session_id, subscribe_id); let track_alias = 3; let publisher_priority = 5; - let object_status = None; + let extension_headers = vec![]; let duration = 1000; // start object cache storage thread @@ -1784,15 +1447,15 @@ mod success { let object_id = i as u64; let datagram_object = datagram::Object::new( - subscribe_id, track_alias, group_id, object_id, publisher_priority, - object_status, + extension_headers.clone(), object_payload, ) .unwrap(); + let datagram_object = DatagramObject::ObjectDatagram(datagram_object); let _ = object_cache_storage .set_datagram_object(&cache_key, datagram_object, duration) @@ -1807,76 +1470,14 @@ mod success { assert!(group_result.is_ok()); - let largest_group_id = group_result.unwrap(); + let largest_group_id = group_result.unwrap().unwrap(); assert_eq!(largest_group_id, expected_group_id); let object_result = object_cache_storage.get_largest_object_id(&cache_key).await; assert!(object_result.is_ok()); - let largest_object = object_result.unwrap(); - assert_eq!(largest_object, expected_object_id); - } - - #[tokio::test] - async fn get_largest_group_id_and_object_id_track() { - let session_id = 0; - let subscribe_id = 1; - let cache_key = CacheKey::new(session_id, subscribe_id); - let track_alias = 3; - let publisher_priority = 5; - let object_status = None; - let duration = 1000; - let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); - - // start object cache storage thread - let (cache_tx, mut cache_rx) = mpsc::channel::(1024); - tokio::spawn(async move { object_cache_storage(&mut cache_rx).await }); - let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); - - let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header) - .await; - - for j in 0..8 { - let group_id = j as u64; - let group_size = 12; - - for i in 0..group_size { - let object_payload: Vec = vec![ - j * group_size + i, - j * group_size + i + 1, - j * group_size + i + 2, - j * group_size + i + 3, - ]; - let object_id = i as u64; - - let track_stream_object = - track_stream::Object::new(group_id, object_id, object_status, object_payload) - .unwrap(); - - let _ = object_cache_storage - .set_track_stream_object(&cache_key, track_stream_object, duration) - .await; - } - } - - let expected_object_id = 11; - let expected_group_id = 7; - - let group_result = object_cache_storage.get_largest_group_id(&cache_key).await; - - assert!(group_result.is_ok()); - - let largest_group_id = group_result.unwrap(); - assert_eq!(largest_group_id, expected_group_id); - - let object_result = object_cache_storage.get_largest_object_id(&cache_key).await; - - assert!(object_result.is_ok()); - - let largest_object = object_result.unwrap(); + let largest_object = object_result.unwrap().unwrap(); assert_eq!(largest_object, expected_object_id); } @@ -1889,10 +1490,10 @@ mod success { let group_id = 4; let subgroup_id = 5; let publisher_priority = 6; + let extension_headers = vec![]; let object_status = None; let duration = 1000; let header = subgroup_stream::Header::new( - subscribe_id, track_alias, group_id, // Group ID is fixed subgroup_id, @@ -1921,8 +1522,13 @@ mod success { ]; let object_id = i as u64; - let subgroup_stream_object = - subgroup_stream::Object::new(object_id, object_status, object_payload).unwrap(); + let subgroup_stream_object = subgroup_stream::Object::new( + object_id, + extension_headers.clone(), + object_status, + object_payload, + ) + .unwrap(); let _ = object_cache_storage .set_subgroup_stream_object( @@ -1943,14 +1549,14 @@ mod success { assert!(group_result.is_ok()); - let largest_group_id = group_result.unwrap(); + let largest_group_id = group_result.unwrap().unwrap(); assert_eq!(largest_group_id, expected_group_id); let object_result = object_cache_storage.get_largest_object_id(&cache_key).await; assert!(object_result.is_ok()); - let largest_object = object_result.unwrap(); + let largest_object = object_result.unwrap().unwrap(); assert_eq!(largest_object, expected_object_id); } @@ -1959,10 +1565,13 @@ mod success { let session_id = 0; let subscribe_id = 1; let cache_key = CacheKey::new(session_id, subscribe_id); + let group_id = 4; + let subgroup_id = 5; let track_alias = 3; let publisher_priority = 6; let header = - track_stream::Header::new(subscribe_id, track_alias, publisher_priority).unwrap(); + subgroup_stream::Header::new(track_alias, group_id, subgroup_id, publisher_priority) + .unwrap(); // start object cache storage thread let (cache_tx, mut cache_rx) = mpsc::channel::(1024); @@ -1971,7 +1580,7 @@ mod success { let mut object_cache_storage = ObjectCacheStorageWrapper::new(cache_tx); let _ = object_cache_storage - .create_track_stream_cache(&cache_key, header.clone()) + .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, header.clone()) .await; let delete_result = object_cache_storage.delete_client(session_id).await; @@ -1979,7 +1588,7 @@ mod success { assert!(delete_result.is_ok()); let get_result = object_cache_storage - .get_track_stream_header(&cache_key) + .get_subgroup_stream_header(&cache_key, group_id, subgroup_id) .await; assert!(get_result.is_err()); diff --git a/moqt-server/src/modules/pubsub_relation_manager/commands.rs b/moqt-server/src/modules/pubsub_relation_manager/commands.rs index e0614d81..27d85ecd 100644 --- a/moqt-server/src/modules/pubsub_relation_manager/commands.rs +++ b/moqt-server/src/modules/pubsub_relation_manager/commands.rs @@ -2,8 +2,11 @@ use anyhow::Result; use tokio::sync::oneshot; use moqt_core::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, + models::{ + range::{ObjectRange, ObjectStart}, + tracks::ForwardingPreference, + }, }; #[cfg(test)] @@ -54,26 +57,12 @@ pub(crate) enum PubSubRelationCommand { downstream_session_id: usize, resp: oneshot::Sender>, }, - IsTrackExisting { + IsUpstreamSubscribed { track_namespace: Vec, track_name: String, resp: oneshot::Sender>, }, - GetUpstreamSubscriptionByFullTrackName { - track_namespace: Vec, - track_name: String, - resp: oneshot::Sender>>, - }, - GetUpstreamSubscriptionBySessionIdAndSubscribeId { - upstream_session_id: usize, - upstream_subscribe_id: u64, - resp: oneshot::Sender>>, - }, - GetDownstreamSubscriptionBySessionIdAndSubscribeId { - downstream_session_id: usize, - downstream_subscribe_id: u64, - resp: oneshot::Sender>>, - }, + // TODO: Unify getter methods of subscribe_id GetUpstreamSessionId { track_namespace: Vec, resp: oneshot::Sender>>, @@ -90,6 +79,16 @@ pub(crate) enum PubSubRelationCommand { upstream_session_id: usize, resp: oneshot::Sender>>, }, + GetDownstreamTrackAlias { + downstream_session_id: usize, + downstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + GetUpstreamSubscribeIdByTrackAlias { + upstream_session_id: usize, + upstream_track_alias: u64, + resp: oneshot::Sender>>, + }, SetDownstreamSubscription { downstream_session_id: usize, subscribe_id: u64, @@ -102,7 +101,6 @@ pub(crate) enum PubSubRelationCommand { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, resp: oneshot::Sender>, }, SetUpstreamSubscription { @@ -115,7 +113,6 @@ pub(crate) enum PubSubRelationCommand { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, resp: oneshot::Sender>, }, SetPubSubRelation { @@ -191,6 +188,97 @@ pub(crate) enum PubSubRelationCommand { upstream_subscribe_id: u64, resp: oneshot::Sender>>, }, + GetUpstreamFilterType { + upstream_session_id: usize, + upstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + GetDownstreamFilterType { + downstream_session_id: usize, + downstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + GetUpstreamRequestedObjectRange { + upstream_session_id: usize, + upstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + GetDownstreamRequestedObjectRange { + downstream_session_id: usize, + downstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + SetDownstreamActualObjectStart { + downstream_session_id: usize, + downstream_subscribe_id: u64, + actual_object_start: ObjectStart, + resp: oneshot::Sender>, + }, + GetDownstreamActualObjectStart { + downstream_session_id: usize, + downstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + SetUpstreamStreamId { + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + resp: oneshot::Sender>, + }, + GetUpstreamSubscribeIdsForClient { + upstream_session_id: usize, + resp: oneshot::Sender>>, + }, + GetUpstreamGroupIdsForSubscription { + upstream_session_id: usize, + upstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + GetUpstreamSubgroupIdsForGroup { + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + resp: oneshot::Sender>>, + }, + GetUpstreamStreamIdForSubgroup { + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + resp: oneshot::Sender>>, + }, + SetDownstreamStreamId { + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + resp: oneshot::Sender>, + }, + GetDownstreamSubscribeIdsForClient { + downstream_session_id: usize, + resp: oneshot::Sender>>, + }, + GetDownstreamGroupIdsForSubscription { + downstream_session_id: usize, + downstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, + GetDownstreamSubgroupIdsForGroup { + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + resp: oneshot::Sender>>, + }, + GetDownstreamStreamIdForSubgroup { + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + resp: oneshot::Sender>>, + }, GetRelatedSubscribers { upstream_session_id: usize, upstream_subscribe_id: u64, diff --git a/moqt-server/src/modules/pubsub_relation_manager/manager.rs b/moqt-server/src/modules/pubsub_relation_manager/manager.rs index 66d6e4fe..1071f9de 100644 --- a/moqt-server/src/modules/pubsub_relation_manager/manager.rs +++ b/moqt-server/src/modules/pubsub_relation_manager/manager.rs @@ -304,52 +304,55 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver { - let consumer = consumers.iter().find(|(_, consumer)| { - consumer.has_track(track_namespace.clone(), track_name.clone()) - }); - let is_existing = consumer.is_some(); - resp.send(Ok(is_existing)).unwrap(); - } - GetUpstreamSubscriptionByFullTrackName { - track_namespace, - track_name, - resp, - } => { - let consumer = consumers.iter().find(|(_, consumer)| { - consumer.has_track(track_namespace.clone(), track_name.clone()) - }); - let result = consumer - .map(|(_, consumer)| { - consumer.get_subscription_by_full_track_name(track_namespace, track_name) - }) - .unwrap(); - - resp.send(result).unwrap(); - } - GetUpstreamSubscriptionBySessionIdAndSubscribeId { + GetUpstreamSubscribeIdByTrackAlias { upstream_session_id, - upstream_subscribe_id, + upstream_track_alias, resp, } => { - let consumer = consumers.get(&upstream_session_id).unwrap(); - let result = consumer.get_subscription(upstream_subscribe_id); + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let result = consumer.get_subscribe_id_by_track_alias(upstream_track_alias); resp.send(result).unwrap(); } - GetDownstreamSubscriptionBySessionIdAndSubscribeId { + GetDownstreamTrackAlias { downstream_session_id, downstream_subscribe_id, resp, } => { - let producer = producers.get(&downstream_session_id).unwrap(); - let result = producer.get_subscription(downstream_subscribe_id); + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; - resp.send(result).unwrap(); + let track_alias = producer.get_track_alias(downstream_subscribe_id).unwrap(); + resp.send(Ok(track_alias)).unwrap(); + } + IsUpstreamSubscribed { + track_namespace, + track_name, + resp, + } => { + let consumer = consumers.iter().find(|(_, consumer)| { + consumer.has_track(track_namespace.clone(), track_name.clone()) + }); + let is_existing = consumer.is_some(); + resp.send(Ok(is_existing)).unwrap(); } SetDownstreamSubscription { downstream_session_id, @@ -363,7 +366,6 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver { // Return an error if the subscriber does not exist @@ -388,7 +390,6 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver resp.send(Ok(())).unwrap(), Err(err) => { @@ -408,7 +409,6 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver { // Return an error if the publisher does not exist @@ -445,7 +445,6 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver resp.send(Ok((subscribe_id, track_alias))).unwrap(), Err(err) => { @@ -794,6 +793,132 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver { + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let filter_type = consumer.get_filter_type(upstream_subscribe_id).unwrap(); + resp.send(Ok(filter_type)).unwrap(); + } + GetDownstreamFilterType { + downstream_session_id, + downstream_subscribe_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let filter_type = producer.get_filter_type(downstream_subscribe_id).unwrap(); + resp.send(Ok(filter_type)).unwrap(); + } + GetUpstreamRequestedObjectRange { + upstream_session_id, + upstream_subscribe_id, + resp, + } => { + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let range = consumer + .get_requested_object_range(upstream_subscribe_id) + .unwrap(); + resp.send(Ok(range)).unwrap(); + } + GetDownstreamRequestedObjectRange { + downstream_session_id, + downstream_subscribe_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let range = producer + .get_requested_object_range(downstream_subscribe_id) + .unwrap(); + resp.send(Ok(range)).unwrap(); + } + SetDownstreamActualObjectStart { + downstream_session_id, + downstream_subscribe_id, + actual_object_start, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get_mut(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + match producer.set_actual_object_start(downstream_subscribe_id, actual_object_start) + { + Ok(_) => resp.send(Ok(())).unwrap(), + Err(err) => { + tracing::error!("set_actual_object_start: err: {:?}", err.to_string()); + resp.send(Err(anyhow!(err))).unwrap(); + } + } + } + GetDownstreamActualObjectStart { + downstream_session_id, + downstream_subscribe_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let actual_object_start = producer + .get_actual_object_start(downstream_subscribe_id) + .unwrap(); + resp.send(Ok(actual_object_start)).unwrap(); + } GetRelatedSubscribers { upstream_session_id, upstream_subscribe_id, @@ -809,6 +934,234 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver { + // Return an error if the publisher does not exist + let consumer = match consumers.get_mut(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + match consumer.set_stream_id( + upstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + ) { + Ok(_) => resp.send(Ok(())).unwrap(), + Err(err) => { + tracing::error!("set_stream_id: err: {:?}", err.to_string()); + resp.send(Err(anyhow!(err))).unwrap(); + } + } + } + GetUpstreamSubscribeIdsForClient { + upstream_session_id, + resp, + } => { + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + let subscribe_ids = consumer.get_all_subscribe_ids().unwrap(); + resp.send(Ok(subscribe_ids)).unwrap(); + } + GetUpstreamGroupIdsForSubscription { + upstream_session_id, + upstream_subscribe_id, + resp, + } => { + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + let group_ids = consumer + .get_group_ids_for_subscription(upstream_subscribe_id) + .unwrap(); + resp.send(Ok(group_ids)).unwrap(); + } + GetUpstreamSubgroupIdsForGroup { + upstream_session_id, + upstream_subscribe_id, + group_id, + resp, + } => { + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let subgroup_ids = consumer + .get_subgroup_ids_for_group(upstream_subscribe_id, group_id) + .unwrap(); + resp.send(Ok(subgroup_ids)).unwrap(); + } + GetUpstreamStreamIdForSubgroup { + upstream_session_id, + upstream_subscribe_id, + group_id, + subgroup_id, + resp, + } => { + // Return an error if the publisher does not exist + let consumer = match consumers.get(&upstream_session_id) { + Some(consumer) => consumer, + None => { + let msg = "publisher not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let stream_id = consumer + .get_stream_id_for_subgroup(upstream_subscribe_id, group_id, subgroup_id) + .unwrap(); + resp.send(Ok(stream_id)).unwrap(); + } + SetDownstreamStreamId { + downstream_session_id, + downstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get_mut(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + match producer.set_stream_id( + downstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + ) { + Ok(_) => resp.send(Ok(())).unwrap(), + Err(err) => { + tracing::error!("set_stream_id: err: {:?}", err.to_string()); + resp.send(Err(anyhow!(err))).unwrap(); + } + } + } + GetDownstreamSubscribeIdsForClient { + downstream_session_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + let subscribe_ids = producer.get_all_subscribe_ids().unwrap(); + resp.send(Ok(subscribe_ids)).unwrap(); + } + GetDownstreamGroupIdsForSubscription { + downstream_session_id, + downstream_subscribe_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + let group_ids = producer + .get_group_ids_for_subscription(downstream_subscribe_id) + .unwrap(); + resp.send(Ok(group_ids)).unwrap(); + } + GetDownstreamSubgroupIdsForGroup { + downstream_session_id, + downstream_subscribe_id, + group_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let stream_ids = producer + .get_subgroup_ids_for_group(downstream_subscribe_id, group_id) + .unwrap(); + resp.send(Ok(stream_ids)).unwrap(); + } + GetDownstreamStreamIdForSubgroup { + downstream_session_id, + downstream_subscribe_id, + group_id, + subgroup_id, + resp, + } => { + // Return an error if the subscriber does not exist + let producer = match producers.get(&downstream_session_id) { + Some(producer) => producer, + None => { + let msg = "subscriber not found"; + tracing::error!(msg); + resp.send(Err(anyhow!(msg))).unwrap(); + continue; + } + }; + + let stream_id = producer + .get_stream_id_for_subgroup(downstream_subscribe_id, group_id, subgroup_id) + .unwrap(); + resp.send(Ok(stream_id)).unwrap(); + } GetRelatedPublisher { downstream_session_id, downstream_subscribe_id, diff --git a/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs b/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs index e510115e..2ce4830d 100644 --- a/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs +++ b/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs @@ -3,8 +3,11 @@ use async_trait::async_trait; use tokio::sync::{mpsc, oneshot}; use moqt_core::{ - messages::control_messages::subscribe::{FilterType, GroupOrder}, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, + messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}, + models::{ + range::{ObjectRange, ObjectStart}, + tracks::ForwardingPreference, + }, pubsub_relation_manager_repository::PubSubRelationManagerRepository, }; @@ -47,6 +50,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn set_upstream_announced_namespace( &self, track_namespace: Vec, @@ -68,6 +72,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn set_downstream_announced_namespace( &self, track_namespace: Vec, @@ -89,6 +94,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn set_downstream_subscribed_namespace_prefix( &self, track_namespace_prefix: Vec, @@ -110,6 +116,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn setup_subscriber( &self, max_subscribe_id: u64, @@ -131,6 +138,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn is_downstream_subscribe_id_unique( &self, subscribe_id: u64, @@ -151,6 +159,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn is_downstream_subscribe_id_less_than_max( &self, subscribe_id: u64, @@ -171,6 +180,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn is_downstream_track_alias_unique( &self, track_alias: u64, @@ -191,13 +201,14 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } - async fn is_track_existing( + + async fn is_upstream_subscribed( &self, track_namespace: Vec, track_name: String, ) -> Result { let (resp_tx, resp_rx) = oneshot::channel::>(); - let cmd = PubSubRelationCommand::IsTrackExisting { + let cmd = PubSubRelationCommand::IsUpstreamSubscribed { track_namespace, track_name, resp: resp_tx, @@ -211,15 +222,11 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } - async fn get_upstream_subscription_by_full_track_name( - &self, - track_namespace: Vec, - track_name: String, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = PubSubRelationCommand::GetUpstreamSubscriptionByFullTrackName { + + async fn get_upstream_session_id(&self, track_namespace: Vec) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamSessionId { track_namespace, - track_name, resp: resp_tx, }; self.tx.send(cmd).await.unwrap(); @@ -227,37 +234,18 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { let result = resp_rx.await.unwrap(); match result { - Ok(subscription) => Ok(subscription), + Ok(upstream_session_id) => Ok(upstream_session_id), Err(err) => bail!(err), } } - async fn get_upstream_subscription_by_ids( - &self, - upstream_session_id: usize, - upstream_subscribe_id: u64, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = PubSubRelationCommand::GetUpstreamSubscriptionBySessionIdAndSubscribeId { - upstream_session_id, - upstream_subscribe_id, - resp: resp_tx, - }; - self.tx.send(cmd).await.unwrap(); - - let result = resp_rx.await.unwrap(); - match result { - Ok(subscription) => Ok(subscription), - Err(err) => bail!(err), - } - } - async fn get_downstream_subscription_by_ids( + async fn get_downstream_track_alias( &self, downstream_session_id: usize, downstream_subscribe_id: u64, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = PubSubRelationCommand::GetDownstreamSubscriptionBySessionIdAndSubscribeId { + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamTrackAlias { downstream_session_id, downstream_subscribe_id, resp: resp_tx, @@ -267,25 +255,11 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { let result = resp_rx.await.unwrap(); match result { - Ok(subscription) => Ok(subscription), + Ok(track_alias) => Ok(track_alias), Err(err) => bail!(err), } } - async fn get_upstream_session_id(&self, track_namespace: Vec) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = PubSubRelationCommand::GetUpstreamSessionId { - track_namespace, - resp: resp_tx, - }; - self.tx.send(cmd).await.unwrap(); - let result = resp_rx.await.unwrap(); - - match result { - Ok(upstream_session_id) => Ok(upstream_session_id), - Err(err) => bail!(err), - } - } async fn get_requesting_downstream_session_ids_and_subscribe_ids( &self, upstream_subscribe_id: u64, @@ -306,6 +280,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn get_upstream_subscribe_id( &self, track_namespace: Vec, @@ -328,6 +303,28 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + + async fn get_upstream_subscribe_id_by_track_alias( + &self, + upstream_session_id: usize, + upstream_track_alias: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamSubscribeIdByTrackAlias { + upstream_session_id, + upstream_track_alias, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(subscribe_id) => Ok(subscribe_id), + Err(err) => bail!(err), + } + } + async fn set_downstream_subscription( &self, downstream_session_id: usize, @@ -341,7 +338,6 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<()> { let (resp_tx, resp_rx) = oneshot::channel::>(); let cmd = PubSubRelationCommand::SetDownstreamSubscription { @@ -356,7 +352,6 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { start_group, start_object, end_group, - end_object, resp: resp_tx, }; self.tx.send(cmd).await.unwrap(); @@ -368,6 +363,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + #[allow(clippy::too_many_arguments)] async fn set_upstream_subscription( &self, @@ -380,7 +376,6 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { start_group: Option, start_object: Option, end_group: Option, - end_object: Option, ) -> Result<(u64, u64)> { let (resp_tx, resp_rx) = oneshot::channel::>(); let cmd = PubSubRelationCommand::SetUpstreamSubscription { @@ -393,7 +388,6 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { start_group, start_object, end_group, - end_object, resp: resp_tx, }; self.tx.send(cmd).await.unwrap(); @@ -405,6 +399,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn set_pubsub_relation( &self, upstream_session_id: usize, @@ -429,6 +424,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn activate_downstream_subscription( &self, downstream_session_id: usize, @@ -448,6 +444,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn activate_upstream_subscription( &self, upstream_session_id: usize, @@ -467,6 +464,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn get_upstream_namespaces_matches_prefix( &self, track_namespace_prefix: Vec, @@ -484,6 +482,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn is_namespace_announced( &self, track_namespace: Vec, @@ -503,6 +502,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn get_downstream_session_ids_by_upstream_namespace( &self, track_namespace: Vec, @@ -520,6 +520,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn delete_upstream_announced_namespace( &self, track_namespace: Vec, @@ -539,6 +540,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn delete_client(&self, session_id: usize) -> Result { let (resp_tx, resp_rx) = oneshot::channel::>(); let cmd = DeleteClient { @@ -553,6 +555,7 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn delete_pubsub_relation( &self, upstream_session_id: usize, @@ -682,13 +685,13 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { } } - async fn get_related_subscribers( + async fn get_upstream_filter_type( &self, upstream_session_id: usize, upstream_subscribe_id: u64, - ) -> Result> { - let (resp_tx, resp_rx) = oneshot::channel::>>(); - let cmd = PubSubRelationCommand::GetRelatedSubscribers { + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamFilterType { upstream_session_id, upstream_subscribe_id, resp: resp_tx, @@ -698,18 +701,18 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { let result = resp_rx.await.unwrap(); match result { - Ok(related_subscribers) => Ok(related_subscribers), + Ok(filter_type) => Ok(filter_type), Err(err) => bail!(err), } } - async fn get_related_publisher( + async fn get_downstream_filter_type( &self, downstream_session_id: usize, downstream_subscribe_id: u64, - ) -> Result<(usize, u64)> { - let (resp_tx, resp_rx) = oneshot::channel::>(); - let cmd = PubSubRelationCommand::GetRelatedPublisher { + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamFilterType { downstream_session_id, downstream_subscribe_id, resp: resp_tx, @@ -719,137 +722,983 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { let result = resp_rx.await.unwrap(); match result { - Ok(related_publisher) => Ok(related_publisher), + Ok(filter_type) => Ok(filter_type), Err(err) => bail!(err), } } -} - -#[cfg(test)] -pub(crate) mod test_helper_fn { - use crate::modules::pubsub_relation_manager::{ - commands::PubSubRelationCommand, - manager::{Consumers, Producers}, - relation::PubSubRelation, - wrapper::PubSubRelationManagerWrapper, - }; - use anyhow::Result; - use tokio::sync::oneshot; + async fn get_upstream_requested_object_range( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamRequestedObjectRange { + upstream_session_id, + upstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); - pub(crate) async fn get_node_and_relation_clone( - pubsub_relation_manager: &PubSubRelationManagerWrapper, - ) -> (Consumers, Producers, PubSubRelation) { - let (resp_tx, resp_rx) = oneshot::channel::>(); - let cmd = PubSubRelationCommand::GetNodeAndRelationClone { resp: resp_tx }; - pubsub_relation_manager.tx.send(cmd).await.unwrap(); + let result = resp_rx.await.unwrap(); - resp_rx.await.unwrap().unwrap() + match result { + Ok(range) => Ok(range), + Err(err) => bail!(err), + } } -} -#[cfg(test)] -mod success { - use crate::modules::pubsub_relation_manager::{ - commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::test_helper_fn, - wrapper::PubSubRelationManagerWrapper, - }; - use moqt_core::messages::control_messages::subscribe::{FilterType, GroupOrder}; - use moqt_core::models::subscriptions::{ - nodes::registry::SubscriptionNodeRegistry, Subscription, - }; - use moqt_core::models::tracks::ForwardingPreference; - use moqt_core::pubsub_relation_manager_repository::PubSubRelationManagerRepository; - use tokio::sync::mpsc; + async fn get_downstream_requested_object_range( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamRequestedObjectRange { + downstream_session_id, + downstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); - #[tokio::test] - async fn setup_publisher() { - let max_subscribe_id = 10; - let upstream_session_id = 1; + let result = resp_rx.await.unwrap(); - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + match result { + Ok(range) => Ok(range), + Err(err) => bail!(err), + } + } - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let result = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) - .await; - assert!(result.is_ok()); + async fn set_downstream_actual_object_start( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + actual_object_start: ObjectStart, + ) -> Result<()> { + let (resp_tx, resp_rx) = oneshot::channel::>(); + let cmd = PubSubRelationCommand::SetDownstreamActualObjectStart { + downstream_session_id, + downstream_subscribe_id, + actual_object_start, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); - // Check if the publisher is created - let (consumers, _, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let length = consumers.len(); + let result = resp_rx.await.unwrap(); - assert_eq!(length, 1); + match result { + Ok(_) => Ok(()), + Err(err) => bail!(err), + } } - #[tokio::test] - async fn set_upstream_announced_namespace() { - let max_subscribe_id = 10; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let upstream_session_id = 1; + async fn get_downstream_actual_object_start( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamActualObjectStart { + downstream_session_id, + downstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + let result = resp_rx.await.unwrap(); - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) - .await; - let result = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; - assert!(result.is_ok()); + match result { + Ok(actual_object_start) => Ok(actual_object_start), + Err(err) => bail!(err), + } + } - // Check if the track_namespace is set - let (consumers, _, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + async fn set_upstream_stream_id( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()> { + let (resp_tx, resp_rx) = oneshot::channel::>(); + let cmd = PubSubRelationCommand::SetUpstreamStreamId { + upstream_session_id, + upstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); - let consumer = consumers.get(&upstream_session_id).unwrap(); - let announced_namespaces = consumer.get_namespaces().unwrap(); - let announced_namespace = announced_namespaces.first().unwrap().to_vec(); + let result = resp_rx.await.unwrap(); - assert_eq!(announced_namespace, track_namespace); + match result { + Ok(_) => Ok(()), + Err(err) => bail!(err), + } } - #[tokio::test] - async fn set_downstream_announced_namespace() { - let max_subscribe_id = 10; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let downstream_session_id = 1; + async fn get_upstream_subscribe_ids_for_client( + &self, + upstream_session_id: usize, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamSubscribeIdsForClient { + upstream_session_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + let result = resp_rx.await.unwrap(); - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) - .await; - let result = pubsub_relation_manager - .set_downstream_announced_namespace(track_namespace.clone(), downstream_session_id) - .await; - assert!(result.is_ok()); + match result { + Ok(subscribe_ids) => Ok(subscribe_ids), + Err(err) => bail!(err), + } + } - // Check if the track_namespace is set - let (_, producers, _) = + async fn get_upstream_group_ids_for_subscription( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamGroupIdsForSubscription { + upstream_session_id, + upstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(group_ids) => Ok(group_ids), + Err(err) => bail!(err), + } + } + + async fn get_upstream_subgroup_ids_for_group( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamSubgroupIdsForGroup { + upstream_session_id, + upstream_subscribe_id, + group_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(stream_ids) => Ok(stream_ids), + Err(err) => bail!(err), + } + } + + async fn get_upstream_stream_id_for_subgroup( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamStreamIdForSubgroup { + upstream_session_id, + upstream_subscribe_id, + group_id, + subgroup_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(stream_id) => Ok(stream_id), + Err(err) => bail!(err), + } + } + + async fn set_downstream_stream_id( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + stream_id: u64, + ) -> Result<()> { + let (resp_tx, resp_rx) = oneshot::channel::>(); + let cmd = PubSubRelationCommand::SetDownstreamStreamId { + downstream_session_id, + downstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(_) => Ok(()), + Err(err) => bail!(err), + } + } + + async fn get_downstream_subscribe_ids_for_client( + &self, + downstream_session_id: usize, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamSubscribeIdsForClient { + downstream_session_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(subscribe_ids) => Ok(subscribe_ids), + Err(err) => bail!(err), + } + } + + async fn get_downstream_group_ids_for_subscription( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamGroupIdsForSubscription { + downstream_session_id, + downstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(group_ids) => Ok(group_ids), + Err(err) => bail!(err), + } + } + + async fn get_downstream_stream_id_for_subgroup( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + subgroup_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamStreamIdForSubgroup { + downstream_session_id, + downstream_subscribe_id, + group_id, + subgroup_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(stream_id) => Ok(stream_id), + Err(err) => bail!(err), + } + } + + async fn get_downstream_subgroup_ids_for_group( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + group_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetDownstreamSubgroupIdsForGroup { + downstream_session_id, + downstream_subscribe_id, + group_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(stream_ids) => Ok(stream_ids), + Err(err) => bail!(err), + } + } + + async fn get_related_subscribers( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetRelatedSubscribers { + upstream_session_id, + upstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(related_subscribers) => Ok(related_subscribers), + Err(err) => bail!(err), + } + } + + async fn get_related_publisher( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result<(usize, u64)> { + let (resp_tx, resp_rx) = oneshot::channel::>(); + let cmd = PubSubRelationCommand::GetRelatedPublisher { + downstream_session_id, + downstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(related_publisher) => Ok(related_publisher), + Err(err) => bail!(err), + } + } +} + +#[cfg(test)] +pub(crate) mod test_helper_fn { + use crate::modules::pubsub_relation_manager::{ + commands::PubSubRelationCommand, + manager::{Consumers, Producers}, + relation::PubSubRelation, + wrapper::PubSubRelationManagerWrapper, + }; + use anyhow::Result; + + use tokio::sync::oneshot; + + pub(crate) async fn get_node_and_relation_clone( + pubsub_relation_manager: &PubSubRelationManagerWrapper, + ) -> (Consumers, Producers, PubSubRelation) { + let (resp_tx, resp_rx) = oneshot::channel::>(); + let cmd = PubSubRelationCommand::GetNodeAndRelationClone { resp: resp_tx }; + pubsub_relation_manager.tx.send(cmd).await.unwrap(); + + resp_rx.await.unwrap().unwrap() + } +} + +#[cfg(test)] +mod success { + use crate::modules::pubsub_relation_manager::{ + commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::test_helper_fn, + wrapper::PubSubRelationManagerWrapper, + }; + use moqt_core::messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}; + use moqt_core::models::range::ObjectStart; + use moqt_core::models::subscriptions::{ + nodes::registry::SubscriptionNodeRegistry, Subscription, + }; + use moqt_core::models::tracks::ForwardingPreference; + use moqt_core::pubsub_relation_manager_repository::PubSubRelationManagerRepository; + use tokio::sync::mpsc; + + #[tokio::test] + async fn setup_publisher() { + let max_subscribe_id = 10; + let upstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let result = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + assert!(result.is_ok()); + + // Check if the publisher is created + let (consumers, _, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let length = consumers.len(); + + assert_eq!(length, 1); + } + + #[tokio::test] + async fn set_upstream_announced_namespace() { + let max_subscribe_id = 10; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let upstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + let result = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + assert!(result.is_ok()); + + // Check if the track_namespace is set + let (consumers, _, _) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let producer = producers.get(&downstream_session_id).unwrap(); - let announced_namespaces = producer.get_namespaces().unwrap(); - let announced_namespace = announced_namespaces.first().unwrap().to_vec(); + let consumer = consumers.get(&upstream_session_id).unwrap(); + let announced_namespaces = consumer.get_namespaces().unwrap(); + let announced_namespace = announced_namespaces.first().unwrap().to_vec(); + + assert_eq!(announced_namespace, track_namespace); + } + + #[tokio::test] + async fn set_downstream_announced_namespace() { + let max_subscribe_id = 10; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + let result = pubsub_relation_manager + .set_downstream_announced_namespace(track_namespace.clone(), downstream_session_id) + .await; + assert!(result.is_ok()); + + // Check if the track_namespace is set + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + + let producer = producers.get(&downstream_session_id).unwrap(); + let announced_namespaces = producer.get_namespaces().unwrap(); + let announced_namespace = announced_namespaces.first().unwrap().to_vec(); + + assert_eq!(announced_namespace, track_namespace); + } + + #[tokio::test] + async fn set_downstream_subscribed_namespace_prefix() { + let max_subscribe_id = 10; + let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + let result = pubsub_relation_manager + .set_downstream_subscribed_namespace_prefix( + track_namespace_prefix.clone(), + downstream_session_id, + ) + .await; + assert!(result.is_ok()); + + // Check if the track_namespace_prefix is set + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + + let producer = producers.get(&downstream_session_id).unwrap(); + let subscribed_namespace_prefixes = producer.get_namespace_prefixes().unwrap(); + let subscribed_namespace_prefix = subscribed_namespace_prefixes.first().unwrap().to_vec(); + + assert_eq!(subscribed_namespace_prefix, track_namespace_prefix); + } + + #[tokio::test] + async fn setup_subscriber() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let result = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + assert!(result.is_ok()); + + // Check if the subscriber is created + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let length = producers.len(); + + assert_eq!(length, 1); + } + + #[tokio::test] + async fn is_downstream_subscribe_id_unique_true() { + let max_subscribe_id = 10; + let subscribe_id = 1; + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + + let result = pubsub_relation_manager + .is_downstream_subscribe_id_unique(subscribe_id, downstream_session_id) + .await; + + let is_unique = result.unwrap(); + assert!(is_unique); + } + + #[tokio::test] + async fn is_downstream_subscribe_id_unique_false() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + + let result = pubsub_relation_manager + .is_downstream_subscribe_id_unique(subscribe_id, downstream_session_id) + .await; + + let is_unique = result.unwrap(); + assert!(!is_unique); + } + + #[tokio::test] + async fn is_downstream_subscribe_id_less_than_max_true() { + let max_subscribe_id = 10; + let subscribe_id = 1; + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + + let result = pubsub_relation_manager + .is_downstream_subscribe_id_less_than_max(subscribe_id, downstream_session_id) + .await; + + let is_less = result.unwrap(); + assert!(is_less); + } + + #[tokio::test] + async fn is_downstream_subscribe_id_less_than_max_false() { + let max_subscribe_id = 10; + let subscribe_id = 11; + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + + let result = pubsub_relation_manager + .is_downstream_subscribe_id_less_than_max(subscribe_id, downstream_session_id) + .await; + + let is_less = result.unwrap(); + assert!(!is_less); + } + + #[tokio::test] + async fn is_downstream_track_alias_unique_true() { + let max_subscribe_id = 10; + let track_alias = 1; + let downstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + + let result = pubsub_relation_manager + .is_downstream_track_alias_unique(track_alias, downstream_session_id) + .await; + + let is_unique = result.unwrap(); + assert!(is_unique); + } + + #[tokio::test] + async fn is_unique_downstream_track_alias_false() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + + let result = pubsub_relation_manager + .is_downstream_track_alias_unique(track_alias, downstream_session_id) + .await; + assert!(result.is_ok()); + + let is_unique = result.unwrap(); + assert!(!is_unique); + } + + #[tokio::test] + async fn is_upstream_subscribed() { + let max_subscribe_id = 10; + let upstream_session_id = 1; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + + let _ = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + + let result = pubsub_relation_manager + .is_upstream_subscribed(track_namespace, track_name) + .await; + assert!(result.is_ok()); + + let is_existing = result.unwrap(); + assert!(is_existing); + } + + #[tokio::test] + async fn not_upstream_subscribed() { + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "test_name".to_string(); + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let result = pubsub_relation_manager + .is_upstream_subscribed(track_namespace, track_name) + .await; + assert!(result.is_ok()); + + let is_existing = result.unwrap(); + assert!(!is_existing); + } + + #[tokio::test] + async fn get_upstream_session_id() { + let max_subscribe_id = 10; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let upstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + + let session_id = pubsub_relation_manager + .get_upstream_session_id(track_namespace.clone()) + .await + .unwrap() + .unwrap(); + + assert_eq!(session_id, upstream_session_id); + } + + #[tokio::test] + async fn get_requesting_downstream_session_ids_and_subscribe_ids() { + let max_subscribe_id = 10; + let upstream_session_id = 1; + let downstream_session_ids = [2, 3]; + let downstream_subscribe_ids = [4, 5]; + let downstream_track_aliases = [6, 7]; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + let (upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); + + for i in [0, 1] { + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_ids[i]) + .await; + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_ids[i], + downstream_subscribe_ids[i], + downstream_track_aliases[i], + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + let _ = pubsub_relation_manager + .set_pubsub_relation( + upstream_session_id, + upstream_subscribe_id, + downstream_session_ids[i], + downstream_subscribe_ids[i], + ) + .await; + } + + let list = pubsub_relation_manager + .get_requesting_downstream_session_ids_and_subscribe_ids( + upstream_subscribe_id, + upstream_session_id, + ) + .await + .unwrap() + .unwrap(); + + let expected_list = vec![ + (downstream_session_ids[0], downstream_subscribe_ids[0]), + (downstream_session_ids[1], downstream_subscribe_ids[1]), + ]; + + assert_eq!(list, expected_list); + } + + #[tokio::test] + async fn get_upstream_subscribe_id() { + let max_subscribe_id = 10; + let upstream_session_id = 1; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + let (expected_upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); + + let upstream_subscribe_id = pubsub_relation_manager + .get_upstream_subscribe_id(track_namespace, track_name, upstream_session_id) + .await + .unwrap() + .unwrap(); - assert_eq!(announced_namespace, track_namespace); + assert_eq!(upstream_subscribe_id, expected_upstream_subscribe_id); } #[tokio::test] - async fn set_downstream_subscribed_namespace_prefix() { + async fn get_downstream_track_alias() { let max_subscribe_id = 10; - let track_namespace_prefix = Vec::from(["test".to_string(), "test".to_string()]); let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -859,73 +1708,83 @@ mod success { let _ = pubsub_relation_manager .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - let result = pubsub_relation_manager - .set_downstream_subscribed_namespace_prefix( - track_namespace_prefix.clone(), + let _ = pubsub_relation_manager + .set_downstream_subscription( downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, ) .await; - assert!(result.is_ok()); - - // Check if the track_namespace_prefix is set - let (_, producers, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let producer = producers.get(&downstream_session_id).unwrap(); - let subscribed_namespace_prefixes = producer.get_namespace_prefixes().unwrap(); - let subscribed_namespace_prefix = subscribed_namespace_prefixes.first().unwrap().to_vec(); + let result_track_alias = pubsub_relation_manager + .get_downstream_track_alias(downstream_session_id, subscribe_id) + .await + .unwrap() + .unwrap(); - assert_eq!(subscribed_namespace_prefix, track_namespace_prefix); + assert_eq!(result_track_alias, track_alias); } #[tokio::test] - async fn setup_subscriber() { + async fn get_subscribe_id_by_track_alias() { let max_subscribe_id = 10; - let downstream_session_id = 1; + let upstream_session_id = 1; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let track_alias = 0; + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let result = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) - .await; - assert!(result.is_ok()); - - // Check if the subscriber is created - let (_, producers, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let length = producers.len(); - - assert_eq!(length, 1); - } - - #[tokio::test] - async fn is_downstream_subscribe_id_unique_true() { - let max_subscribe_id = 10; - let subscribe_id = 1; - let downstream_session_id = 1; - - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) + .setup_publisher(max_subscribe_id, upstream_session_id) .await; - - let result = pubsub_relation_manager - .is_downstream_subscribe_id_unique(subscribe_id, downstream_session_id) + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; + let (expected_upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); - let is_unique = result.unwrap(); - assert!(is_unique); + let upstream_subscribe_id = pubsub_relation_manager + .get_upstream_subscribe_id_by_track_alias(upstream_session_id, track_alias) + .await + .unwrap() + .unwrap(); + + assert_eq!(upstream_subscribe_id, expected_upstream_subscribe_id); } #[tokio::test] - async fn is_downstream_subscribe_id_unique_false() { + async fn set_downstream_subscription() { let max_subscribe_id = 10; let downstream_session_id = 1; let subscribe_id = 0; @@ -938,7 +1797,6 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -948,36 +1806,58 @@ mod success { let _ = pubsub_relation_manager .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - let _ = pubsub_relation_manager + let result = pubsub_relation_manager .set_downstream_subscription( downstream_session_id, subscribe_id, track_alias, - track_namespace, - track_name, + track_namespace.clone(), + track_name.clone(), subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, ) .await; - let result = pubsub_relation_manager - .is_downstream_subscribe_id_unique(subscribe_id, downstream_session_id) - .await; + assert!(result.is_ok()); - let is_unique = result.unwrap(); - assert!(!is_unique); + // Assert that the subscription is set + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let producer = producers.get(&downstream_session_id).unwrap(); + let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); + + let expected_subscription = Subscription::new( + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + None, + ); + + assert_eq!(subscription, expected_subscription); } #[tokio::test] - async fn is_downstream_subscribe_id_less_than_max_true() { + async fn set_upstream_subscription() { let max_subscribe_id = 10; - let subscribe_id = 1; - let downstream_session_id = 1; + let upstream_session_id = 1; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -985,65 +1865,145 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) + .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let result = pubsub_relation_manager - .is_downstream_subscribe_id_less_than_max(subscribe_id, downstream_session_id) + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) .await; - let is_less = result.unwrap(); - assert!(is_less); + assert!(result.is_ok()); + + let (upstream_subscribe_id, upstream_track_alias) = result.unwrap(); + + // Assert that the subscription is set + let (consumers, _, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let consumer = consumers.get(&upstream_session_id).unwrap(); + let subscription = consumer + .get_subscription(upstream_subscribe_id) + .unwrap() + .unwrap(); + + let expected_subscription = Subscription::new( + upstream_track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + None, + ); + + assert_eq!(subscription, expected_subscription); } #[tokio::test] - async fn is_downstream_subscribe_id_less_than_max_false() { + async fn set_pubsub_relation() { let max_subscribe_id = 10; - let subscribe_id = 11; - let downstream_session_id = 1; + let upstream_session_id = 1; + let downstream_session_ids = [2, 3]; + let downstream_subscribe_ids = [4, 5]; + let downstream_track_aliases = [6, 7]; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + // pub 1 <- sub 2, 3 let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) + .setup_publisher(max_subscribe_id, upstream_session_id) .await; - - let result = pubsub_relation_manager - .is_downstream_subscribe_id_less_than_max(subscribe_id, downstream_session_id) + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; + let (upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); - let is_less = result.unwrap(); - assert!(!is_less); - } + for i in [0, 1] { + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_ids[i]) + .await; + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_ids[i], + downstream_subscribe_ids[i], + downstream_track_aliases[i], + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + let result = pubsub_relation_manager + .set_pubsub_relation( + upstream_session_id, + upstream_subscribe_id, + downstream_session_ids[i], + downstream_subscribe_ids[i], + ) + .await; - #[tokio::test] - async fn is_downstream_track_alias_unique_true() { - let max_subscribe_id = 10; - let track_alias = 1; - let downstream_session_id = 1; + assert!(result.is_ok()); + } - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + // Assert that the relation is registered + let (_, _, pubsub_relation) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) - .await; + let subscriber = pubsub_relation + .get_subscribers(upstream_session_id, upstream_subscribe_id) + .unwrap() + .to_vec(); - let result = pubsub_relation_manager - .is_downstream_track_alias_unique(track_alias, downstream_session_id) - .await; + let expected_subscriber = vec![ + (downstream_session_ids[0], downstream_subscribe_ids[0]), + (downstream_session_ids[1], downstream_subscribe_ids[1]), + ]; - let is_unique = result.unwrap(); - assert!(is_unique); + assert_eq!(subscriber, expected_subscriber); } #[tokio::test] - async fn is_unique_downstream_track_alias_false() { + async fn activate_downstream_subscription() { let max_subscribe_id = 10; let downstream_session_id = 1; let subscribe_id = 0; @@ -1056,7 +2016,6 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1079,21 +2038,27 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; - let result = pubsub_relation_manager - .is_downstream_track_alias_unique(track_alias, downstream_session_id) - .await; - assert!(result.is_ok()); + let activate_occured = pubsub_relation_manager + .activate_downstream_subscription(downstream_session_id, subscribe_id) + .await + .unwrap(); - let is_unique = result.unwrap(); - assert!(!is_unique); + assert!(activate_occured); + + // Assert that the subscription is active + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let producer = producers.get(&downstream_session_id).unwrap(); + let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); + + assert!(subscription.is_active()); } #[tokio::test] - async fn is_track_existing_exists() { + async fn activate_upstream_subscription() { let max_subscribe_id = 10; let upstream_session_id = 1; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); @@ -1104,23 +2069,16 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - - let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; - - let _ = pubsub_relation_manager + let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, track_namespace.clone(), @@ -1131,44 +2089,34 @@ mod success { start_group, start_object, end_group, - end_object, ) - .await; - - let result = pubsub_relation_manager - .is_track_existing(track_namespace, track_name) - .await; - assert!(result.is_ok()); - - let is_existing = result.unwrap(); - assert!(is_existing); - } + .await + .unwrap(); - #[tokio::test] - async fn is_track_existing_not_exists() { - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let track_name = "test_name".to_string(); + let activate_occured = pubsub_relation_manager + .activate_upstream_subscription(upstream_session_id, upstream_subscribe_id) + .await + .unwrap(); - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + assert!(activate_occured); - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let result = pubsub_relation_manager - .is_track_existing(track_namespace, track_name) - .await; - assert!(result.is_ok()); + // Assert that the subscription is active + let (consumers, _, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let consumer = consumers.get(&upstream_session_id).unwrap(); + let subscription = consumer + .get_subscription(upstream_subscribe_id) + .unwrap() + .unwrap(); - let is_existing = result.unwrap(); - assert!(!is_existing); + assert!(subscription.is_active()); } #[tokio::test] - async fn get_upstream_subscription_by_full_track_name() { + async fn get_upstream_namespaces_matches_prefix_exist() { let max_subscribe_id = 10; let upstream_session_id = 1; - let track_alias = 0; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; @@ -1176,7 +2124,7 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; + let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1186,6 +2134,9 @@ mod success { let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; let _ = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, @@ -1197,43 +2148,22 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; - let subscription = pubsub_relation_manager - .get_upstream_subscription_by_full_track_name( - track_namespace.clone(), - track_name.clone(), - ) + let namespaces = pubsub_relation_manager + .get_upstream_namespaces_matches_prefix(track_namespace_prefix) .await .unwrap(); - let forwarding_preference = None; - let expected_subscription = Subscription::new( - track_alias, - track_namespace, - track_name, - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - forwarding_preference, - ); - - assert_eq!(subscription, Some(expected_subscription)); + assert_eq!(namespaces, vec![track_namespace]); } #[tokio::test] - async fn get_upstream_subscription_by_ids() { + async fn get_upstream_namespaces_matches_prefix_not_exist() { let max_subscribe_id = 10; let upstream_session_id = 1; - let upstream_subscribe_id = 0; - let track_alias = 0; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; @@ -1241,7 +2171,7 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; + let track_namespace_prefix = Vec::from(["aa".to_string()]); // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1251,6 +2181,9 @@ mod success { let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; let _ = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, @@ -1262,48 +2195,123 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; - let subscription = pubsub_relation_manager - .get_upstream_subscription_by_ids(upstream_session_id, upstream_subscribe_id) + let namespaces = pubsub_relation_manager + .get_upstream_namespaces_matches_prefix(track_namespace_prefix) .await .unwrap(); - let forwarding_preference = None; - let expected_subscription = Subscription::new( - track_alias, - track_namespace, - track_name, - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - forwarding_preference, - ); + let expected_namespaces: Vec> = vec![]; + + assert_eq!(namespaces, expected_namespaces); + } + + #[tokio::test] + async fn is_namespace_announced_exist() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string()]); + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_downstream_announced_namespace(track_namespace.clone(), downstream_session_id) + .await; + + let result = pubsub_relation_manager + .is_namespace_announced(track_namespace.clone(), downstream_session_id) + .await; + + let is_announced = result.unwrap(); + assert!(is_announced); + } + + #[tokio::test] + async fn is_namespace_announced_not_exist() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string()]); + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + + let result = pubsub_relation_manager + .is_namespace_announced(track_namespace.clone(), downstream_session_id) + .await; + + let is_announced = result.unwrap(); + assert!(!is_announced); + } + + #[tokio::test] + async fn get_downstream_session_ids_by_upstream_namespace() { + let max_subscribe_id = 10; + let upstream_session_id = 1; + let downstream_session_ids = [2, 3]; + let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); + let track_namespace_prefixes = Vec::from([ + Vec::from(["aaa".to_string()]), + Vec::from(["bbb".to_string()]), + ]); + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + // pub 1 <- sub 2, 3 + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + + for i in [0, 1] { + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_ids[i]) + .await; + + let _ = pubsub_relation_manager + .set_downstream_subscribed_namespace_prefix( + track_namespace_prefixes[i].clone(), + downstream_session_ids[i], + ) + .await; + } + + let result = pubsub_relation_manager + .get_downstream_session_ids_by_upstream_namespace(track_namespace) + .await; - assert_eq!(subscription, Some(expected_subscription)); + assert!(result.is_ok()); + + let expected_downstream_session_ids = vec![downstream_session_ids[0]]; + + assert_eq!(result.unwrap(), expected_downstream_session_ids); } #[tokio::test] - async fn get_downstream_subscription_by_ids() { + async fn delete_upstream_announced_namespace() { let max_subscribe_id = 10; - let downstream_session_id = 1; - let downstream_subscribe_id = 1; - let track_alias = 0; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let track_name = "track_name".to_string(); - let subscriber_priority = 0; - let group_order = GroupOrder::Ascending; - let filter_type = FilterType::AbsoluteStart; - let start_group = Some(0); - let start_object = Some(0); - let end_group = None; - let end_object = None; + let upstream_session_id = 1; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1311,50 +2319,31 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) + .setup_publisher(max_subscribe_id, upstream_session_id) .await; let _ = pubsub_relation_manager - .set_downstream_subscription( - downstream_session_id, - downstream_subscribe_id, - track_alias, - track_namespace.clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; - let subscription = pubsub_relation_manager - .get_downstream_subscription_by_ids(downstream_session_id, downstream_subscribe_id) - .await - .unwrap(); + let result = pubsub_relation_manager + .delete_upstream_announced_namespace(track_namespace, upstream_session_id) + .await; + assert!(result.is_ok()); - let forwarding_preference = None; - let expected_subscription = Subscription::new( - track_alias, - track_namespace, - track_name, - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - forwarding_preference, - ); + let delete_occured = result.unwrap(); + assert!(delete_occured); + + // Assert that the announced namespace is deleted + let (consumers, _, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let consumer = consumers.get(&upstream_session_id).unwrap(); + let announced_namespaces = consumer.get_namespaces().unwrap().to_vec(); - assert_eq!(subscription, Some(expected_subscription)); + assert!(announced_namespaces.is_empty()); } #[tokio::test] - async fn get_upstream_session_id() { + async fn delete_upstream_announced_namespace_not_exists() { let max_subscribe_id = 10; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let upstream_session_id = 1; @@ -1368,35 +2357,34 @@ mod success { let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + let result = pubsub_relation_manager + .delete_upstream_announced_namespace(track_namespace, upstream_session_id) .await; + assert!(result.is_ok()); - let session_id = pubsub_relation_manager - .get_upstream_session_id(track_namespace.clone()) - .await - .unwrap() - .unwrap(); - - assert_eq!(session_id, upstream_session_id); + let delete_occured = result.unwrap(); + assert!(!delete_occured); } #[tokio::test] - async fn get_requesting_downstream_session_ids_and_subscribe_ids() { + async fn delete_client() { let max_subscribe_id = 10; - let upstream_session_id = 1; - let downstream_session_ids = [2, 3]; - let downstream_subscribe_ids = [4, 5]; - let downstream_track_aliases = [6, 7]; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let track_name = "track_name".to_string(); + let track_namespaces = [ + Vec::from(["test1".to_string(), "test1".to_string()]), + Vec::from(["test2".to_string(), "test2".to_string()]), + ]; + let upstream_session_ids = [1, 2]; + let mut upstream_subscribe_ids = vec![]; + let downstream_session_ids = [2, 3, 4]; + let downstream_subscribe_ids = [2, 3, 4]; + let downstream_track_aliases = [2, 3, 4]; + let track_name = "test_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; let filter_type = FilterType::AbsoluteStart; let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1404,38 +2392,24 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) - .await; - let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; - let (upstream_subscribe_id, _) = pubsub_relation_manager - .set_upstream_subscription( - upstream_session_id, - track_namespace.clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) - .await - .unwrap(); - + // Register: + // pub 1 <- sub 2, 3, 4 + // pub 2 <- sub 3, 4 for i in [0, 1] { + // for pub 1, 2 let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_ids[i]) + .setup_publisher(max_subscribe_id, upstream_session_ids[i]) .await; let _ = pubsub_relation_manager - .set_downstream_subscription( - downstream_session_ids[i], - downstream_subscribe_ids[i], - downstream_track_aliases[i], - track_namespace.clone(), + .set_upstream_announced_namespace( + track_namespaces[i].clone(), + upstream_session_ids[i], + ) + .await; + let (upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_ids[i], + track_namespaces[i].clone(), track_name.clone(), subscriber_priority, group_order, @@ -1443,49 +2417,146 @@ mod success { start_group, start_object, end_group, - end_object, ) - .await; + .await + .unwrap(); + upstream_subscribe_ids.push(upstream_subscribe_id); + } + + for j in [0, 1, 2] { + // for sub 2, 3, 4 let _ = pubsub_relation_manager - .set_pubsub_relation( - upstream_session_id, - upstream_subscribe_id, - downstream_session_ids[i], - downstream_subscribe_ids[i], - ) + .setup_subscriber(max_subscribe_id, downstream_session_ids[j]) .await; } - let list = pubsub_relation_manager - .get_requesting_downstream_session_ids_and_subscribe_ids( - upstream_subscribe_id, - upstream_session_id, - ) - .await - .unwrap() - .unwrap(); + // for sub 2 + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_ids[0], + downstream_subscribe_ids[0], + downstream_track_aliases[0], + track_namespaces[0].clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + + for i in [0, 1] { + // for pub 1, 2 + for j in [1, 2] { + // for sub 3, 4 + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_ids[j], + downstream_subscribe_ids[j], + downstream_track_aliases[j], + track_namespaces[i].clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await; + + let _ = pubsub_relation_manager + .set_pubsub_relation( + upstream_session_ids[i], + upstream_subscribe_ids[i], + downstream_session_ids[j], + downstream_subscribe_ids[j], + ) + .await; + let _ = pubsub_relation_manager + .activate_downstream_subscription( + downstream_session_ids[j], + downstream_subscribe_ids[j], + ) + .await; + + let _ = pubsub_relation_manager + .activate_upstream_subscription( + upstream_session_ids[i], + upstream_subscribe_ids[i], + ) + .await; + } + } + + // for pub 1 and sub 2 + let _ = pubsub_relation_manager + .set_pubsub_relation( + upstream_session_ids[0], + upstream_subscribe_ids[0], + downstream_session_ids[0], + downstream_subscribe_ids[0], + ) + .await; + let _ = pubsub_relation_manager + .activate_downstream_subscription( + downstream_session_ids[0], + downstream_subscribe_ids[0], + ) + .await; + + let _ = pubsub_relation_manager + .activate_upstream_subscription(upstream_session_ids[0], upstream_subscribe_ids[0]) + .await; + + // Delete: pub 2, sub 2 + // Remain: pub 1 <- sub 3, 4 + let result = pubsub_relation_manager + .delete_client(downstream_session_ids[0]) + .await; + assert!(result.is_ok()); + + let delete_occured = result.unwrap(); + assert!(delete_occured); + + let (consumers, producers, pubsub_relation) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + + // Assert that sub 2 is deleted + // Remain: sub 3, 4 + let sub2 = producers.get(&downstream_session_ids[0]); + assert!(sub2.is_none()); + + let sub3 = producers.get(&downstream_session_ids[1]); + assert!(sub3.is_some()); - let expected_list = vec![ - (downstream_session_ids[0], downstream_subscribe_ids[0]), - (downstream_session_ids[1], downstream_subscribe_ids[1]), - ]; + let sub4 = producers.get(&downstream_session_ids[2]); + assert!(sub4.is_some()); - assert_eq!(list, expected_list); + // Assert that pub 2 is deleted + // Remain: pub 1 + let pub1 = consumers.get(&upstream_session_ids[1]); + assert!(pub1.is_none()); + + let pub2 = consumers.get(&upstream_session_ids[0]); + assert!(pub2.is_some()); + + // Assert that the relation is deleted + // Remain: pub 1 <- sub 3, 4 + let pub1_relation = + pubsub_relation.get_subscribers(upstream_session_ids[0], upstream_subscribe_ids[0]); + assert!(pub1_relation.is_some()); + + let pub2_relation = + pubsub_relation.get_subscribers(upstream_session_ids[1], upstream_subscribe_ids[1]); + assert!(pub2_relation.is_none()); } #[tokio::test] - async fn get_upstream_subscribe_id() { - let max_subscribe_id = 10; - let upstream_session_id = 1; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let track_name = "track_name".to_string(); - let subscriber_priority = 0; - let group_order = GroupOrder::Ascending; - let filter_type = FilterType::AbsoluteStart; - let start_group = Some(0); - let start_object = Some(0); - let end_group = None; - let end_object = None; + async fn delete_client_not_exists() { + let session_id = 1; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1493,43 +2564,19 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) - .await; - let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; - let (expected_upstream_subscribe_id, _) = pubsub_relation_manager - .set_upstream_subscription( - upstream_session_id, - track_namespace.clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) - .await - .unwrap(); - - let upstream_subscribe_id = pubsub_relation_manager - .get_upstream_subscribe_id(track_namespace, track_name, upstream_session_id) - .await - .unwrap() - .unwrap(); + let result = pubsub_relation_manager.delete_client(session_id).await; + assert!(result.is_ok()); - assert_eq!(upstream_subscribe_id, expected_upstream_subscribe_id); + let delete_occured = result.unwrap(); + assert!(!delete_occured); } #[tokio::test] - async fn set_downstream_subscription() { + async fn delete_pubsub_relation() { let max_subscribe_id = 10; - let downstream_session_id = 1; - let subscribe_id = 0; - let track_alias = 0; + let upstream_session_id = 1; + let downstream_session_id = 2; + let downstream_subscribe_id = 3; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; @@ -1538,21 +2585,26 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; let _ = pubsub_relation_manager .setup_subscriber(max_subscribe_id, downstream_session_id) .await; - let result = pubsub_relation_manager + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + let _ = pubsub_relation_manager .set_downstream_subscription( downstream_session_id, - subscribe_id, - track_alias, + max_subscribe_id, + 0, track_namespace.clone(), track_name.clone(), subscriber_priority, @@ -1561,58 +2613,9 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; - - assert!(result.is_ok()); - - // Assert that the subscription is set - let (_, producers, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let producer = producers.get(&downstream_session_id).unwrap(); - let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); - - let expected_subscription = Subscription::new( - track_alias, - track_namespace, - track_name, - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - None, - ); - - assert_eq!(subscription, expected_subscription); - } - - #[tokio::test] - async fn set_upstream_subscription() { - let max_subscribe_id = 10; - let upstream_session_id = 1; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let track_name = "track_name".to_string(); - let subscriber_priority = 0; - let group_order = GroupOrder::Ascending; - let filter_type = FilterType::AbsoluteStart; - let start_group = Some(0); - let start_object = Some(0); - let end_group = None; - let end_object = None; - - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); - - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) - .await; - let result = pubsub_relation_manager + let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, track_namespace.clone(), @@ -1623,47 +2626,43 @@ mod success { start_group, start_object, end_group, - end_object, + ) + .await + .unwrap(); + let _ = pubsub_relation_manager + .set_pubsub_relation( + upstream_session_id, + upstream_subscribe_id, + downstream_session_id, + downstream_subscribe_id, ) .await; - assert!(result.is_ok()); + let result = pubsub_relation_manager + .delete_pubsub_relation( + upstream_session_id, + upstream_subscribe_id, + downstream_session_id, + downstream_subscribe_id, + ) + .await; - let (upstream_subscribe_id, upstream_track_alias) = result.unwrap(); + assert!(result.is_ok()); - // Assert that the subscription is set - let (consumers, _, _) = + let (_, _, pubsub_relation) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let consumer = consumers.get(&upstream_session_id).unwrap(); - let subscription = consumer - .get_subscription(upstream_subscribe_id) - .unwrap() - .unwrap(); - let expected_subscription = Subscription::new( - upstream_track_alias, - track_namespace, - track_name, - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - None, - ); + let relation = pubsub_relation + .get_subscribers(upstream_session_id, upstream_subscribe_id) + .unwrap(); - assert_eq!(subscription, expected_subscription); + assert!(relation.is_empty()); } #[tokio::test] - async fn set_pubsub_relation() { + async fn delete_upstream_subscription() { let max_subscribe_id = 10; let upstream_session_id = 1; - let downstream_session_ids = [2, 3]; - let downstream_subscribe_ids = [4, 5]; - let downstream_track_aliases = [6, 7]; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; @@ -1672,21 +2671,15 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - - // pub 1 <- sub 2, 3 let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, @@ -1698,62 +2691,25 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); - for i in [0, 1] { - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_ids[i]) - .await; - let _ = pubsub_relation_manager - .set_downstream_subscription( - downstream_session_ids[i], - downstream_subscribe_ids[i], - downstream_track_aliases[i], - track_namespace.clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) - .await; - let result = pubsub_relation_manager - .set_pubsub_relation( - upstream_session_id, - upstream_subscribe_id, - downstream_session_ids[i], - downstream_subscribe_ids[i], - ) - .await; - - assert!(result.is_ok()); - } + let result = pubsub_relation_manager + .delete_upstream_subscription(upstream_session_id, upstream_subscribe_id) + .await; + assert!(result.is_ok()); - // Assert that the relation is registered - let (_, _, pubsub_relation) = + let (consumers, _, _) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let consumer = consumers.get(&upstream_session_id).unwrap(); + let subscription = consumer.get_subscription(upstream_subscribe_id).unwrap(); - let subscriber = pubsub_relation - .get_subscribers(upstream_session_id, upstream_subscribe_id) - .unwrap() - .to_vec(); - - let expected_subscriber = vec![ - (downstream_session_ids[0], downstream_subscribe_ids[0]), - (downstream_session_ids[1], downstream_subscribe_ids[1]), - ]; - - assert_eq!(subscriber, expected_subscriber); + assert!(subscription.is_none()); } #[tokio::test] - async fn activate_downstream_subscription() { + async fn delete_downstream_subscription() { let max_subscribe_id = 10; let downstream_session_id = 1; let subscribe_id = 0; @@ -1766,7 +2722,6 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1781,36 +2736,32 @@ mod success { downstream_session_id, subscribe_id, track_alias, - track_namespace, - track_name, + track_namespace.clone(), + track_name.clone(), subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, ) .await; - let activate_occured = pubsub_relation_manager - .activate_downstream_subscription(downstream_session_id, subscribe_id) - .await - .unwrap(); - - assert!(activate_occured); + let result = pubsub_relation_manager + .delete_downstream_subscription(downstream_session_id, subscribe_id) + .await; + assert!(result.is_ok()); - // Assert that the subscription is active let (_, producers, _) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; let producer = producers.get(&downstream_session_id).unwrap(); - let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); + let subscription = producer.get_subscription(subscribe_id).unwrap(); - assert!(subscription.is_active()); + assert!(subscription.is_none()); } #[tokio::test] - async fn activate_upstream_subscription() { + async fn set_upstream_forwarding_preference() { let max_subscribe_id = 10; let upstream_session_id = 1; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); @@ -1821,7 +2772,7 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; + let forwarding_preference = ForwardingPreference::Subgroup; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1831,6 +2782,9 @@ mod success { let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, @@ -1842,19 +2796,19 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); - let activate_occured = pubsub_relation_manager - .activate_upstream_subscription(upstream_session_id, upstream_subscribe_id) - .await - .unwrap(); - - assert!(activate_occured); + let result = pubsub_relation_manager + .set_upstream_forwarding_preference( + upstream_session_id, + upstream_subscribe_id, + forwarding_preference.clone(), + ) + .await; + assert!(result.is_ok()); - // Assert that the subscription is active let (consumers, _, _) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; let consumer = consumers.get(&upstream_session_id).unwrap(); @@ -1863,14 +2817,16 @@ mod success { .unwrap() .unwrap(); - assert!(subscription.is_active()); + let result_forwarding_preference = subscription.get_forwarding_preference().unwrap(); + + assert_eq!(result_forwarding_preference, forwarding_preference); } #[tokio::test] - async fn get_upstream_namespaces_matches_prefix_exist() { + async fn get_upstream_forwarding_preference() { let max_subscribe_id = 10; let upstream_session_id = 1; - let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; @@ -1878,8 +2834,7 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; - let track_namespace_prefix = Vec::from(["aaa".to_string(), "bbb".to_string()]); + let forwarding_preference = ForwardingPreference::Subgroup; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1892,7 +2847,7 @@ mod success { let _ = pubsub_relation_manager .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; - let _ = pubsub_relation_manager + let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, track_namespace.clone(), @@ -1903,23 +2858,34 @@ mod success { start_group, start_object, end_group, - end_object, ) - .await; - - let namespaces = pubsub_relation_manager - .get_upstream_namespaces_matches_prefix(track_namespace_prefix) .await .unwrap(); + let _ = pubsub_relation_manager + .set_upstream_forwarding_preference( + upstream_session_id, + upstream_subscribe_id, + forwarding_preference.clone(), + ) + .await; - assert_eq!(namespaces, vec![track_namespace]); + let result = pubsub_relation_manager + .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) + .await; + assert!(result.is_ok()); + + let result_forwarding_preference = result.unwrap().unwrap(); + + assert_eq!(result_forwarding_preference, forwarding_preference); } #[tokio::test] - async fn get_upstream_namespaces_matches_prefix_not_exist() { + async fn set_downstream_forwarding_preference() { let max_subscribe_id = 10; - let upstream_session_id = 1; - let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; @@ -1927,8 +2893,7 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; - let track_namespace_prefix = Vec::from(["aa".to_string()]); + let forwarding_preference = ForwardingPreference::Subgroup; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -1936,14 +2901,13 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) - .await; - let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .setup_subscriber(max_subscribe_id, downstream_session_id) .await; let _ = pubsub_relation_manager - .set_upstream_subscription( - upstream_session_id, + .set_downstream_subscription( + downstream_session_id, + subscribe_id, + track_alias, track_namespace.clone(), track_name.clone(), subscriber_priority, @@ -1952,124 +2916,90 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; - let namespaces = pubsub_relation_manager - .get_upstream_namespaces_matches_prefix(track_namespace_prefix) - .await - .unwrap(); - - let expected_namespaces: Vec> = vec![]; - - assert_eq!(namespaces, expected_namespaces); - } - - #[tokio::test] - async fn is_namespace_announced_exist() { - let max_subscribe_id = 10; - let downstream_session_id = 1; - let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string()]); - - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); - - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) - .await; - let _ = pubsub_relation_manager - .set_downstream_announced_namespace(track_namespace.clone(), downstream_session_id) - .await; - let result = pubsub_relation_manager - .is_namespace_announced(track_namespace.clone(), downstream_session_id) + .set_downstream_forwarding_preference( + downstream_session_id, + subscribe_id, + forwarding_preference.clone(), + ) .await; + assert!(result.is_ok()); - let is_announced = result.unwrap(); - assert!(is_announced); - } - - #[tokio::test] - async fn is_namespace_announced_not_exist() { - let max_subscribe_id = 10; - let downstream_session_id = 1; - let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string()]); - - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); - - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) - .await; + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let producer = producers.get(&downstream_session_id).unwrap(); + let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); - let result = pubsub_relation_manager - .is_namespace_announced(track_namespace.clone(), downstream_session_id) - .await; + let result_forwarding_preference = subscription.get_forwarding_preference().unwrap(); - let is_announced = result.unwrap(); - assert!(!is_announced); + assert_eq!(result_forwarding_preference, forwarding_preference); } #[tokio::test] - async fn get_downstream_session_ids_by_upstream_namespace() { + async fn get_upstream_filter_type() { let max_subscribe_id = 10; let upstream_session_id = 1; - let downstream_session_ids = [2, 3]; - let track_namespace = Vec::from(["aaa".to_string(), "bbb".to_string(), "ccc".to_string()]); - let track_namespace_prefixes = Vec::from([ - Vec::from(["aaa".to_string()]), - Vec::from(["bbb".to_string()]), - ]); + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - - // pub 1 <- sub 2, 3 let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let _ = pubsub_relation_manager .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; + let (upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); - for i in [0, 1] { - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_ids[i]) - .await; - - let _ = pubsub_relation_manager - .set_downstream_subscribed_namespace_prefix( - track_namespace_prefixes[i].clone(), - downstream_session_ids[i], - ) - .await; - } - - let result = pubsub_relation_manager - .get_downstream_session_ids_by_upstream_namespace(track_namespace) - .await; - - assert!(result.is_ok()); - - let expected_downstream_session_ids = vec![downstream_session_ids[0]]; + let result_filter_type = pubsub_relation_manager + .get_upstream_filter_type(upstream_session_id, upstream_subscribe_id) + .await + .unwrap() + .unwrap(); - assert_eq!(result.unwrap(), expected_downstream_session_ids); + assert_eq!(result_filter_type, filter_type); } #[tokio::test] - async fn delete_upstream_announced_namespace() { + async fn get_downstream_filter_type() { let max_subscribe_id = 10; + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); - let upstream_session_id = 1; + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2077,268 +3007,196 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) + .setup_subscriber(max_subscribe_id, downstream_session_id) .await; let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; - - let result = pubsub_relation_manager - .delete_upstream_announced_namespace(track_namespace, upstream_session_id) + .set_downstream_subscription( + downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) .await; - assert!(result.is_ok()); - let delete_occured = result.unwrap(); - assert!(delete_occured); - - // Assert that the announced namespace is deleted - let (consumers, _, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let consumer = consumers.get(&upstream_session_id).unwrap(); - let announced_namespaces = consumer.get_namespaces().unwrap().to_vec(); + let result_filter_type = pubsub_relation_manager + .get_downstream_filter_type(downstream_session_id, subscribe_id) + .await + .unwrap() + .unwrap(); - assert!(announced_namespaces.is_empty()); + assert_eq!(result_filter_type, filter_type); } #[tokio::test] - async fn delete_upstream_announced_namespace_not_exists() { + async fn get_upstream_requested_object_range() { let max_subscribe_id = 10; - let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let upstream_session_id = 1; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let result = pubsub_relation_manager - .delete_upstream_announced_namespace(track_namespace, upstream_session_id) + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; - assert!(result.is_ok()); + let (upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); - let delete_occured = result.unwrap(); - assert!(!delete_occured); + let result_range = pubsub_relation_manager + .get_upstream_requested_object_range(upstream_session_id, upstream_subscribe_id) + .await + .unwrap() + .unwrap(); + + assert_eq!(result_range.start_group(), start_group); + assert_eq!(result_range.start_object(), start_object); + assert_eq!(result_range.end_group(), end_group); } #[tokio::test] - async fn delete_client() { + async fn get_downstream_requested_object_range() { let max_subscribe_id = 10; - let track_namespaces = [ - Vec::from(["test1".to_string(), "test1".to_string()]), - Vec::from(["test2".to_string(), "test2".to_string()]), - ]; - let upstream_session_ids = [1, 2]; - let mut upstream_subscribe_ids = vec![]; - let downstream_session_ids = [2, 3, 4]; - let downstream_subscribe_ids = [2, 3, 4]; - let downstream_track_aliases = [2, 3, 4]; - let track_name = "test_name".to_string(); + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); let subscriber_priority = 0; let group_order = GroupOrder::Ascending; let filter_type = FilterType::AbsoluteStart; let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; - // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - - // Register: - // pub 1 <- sub 2, 3, 4 - // pub 2 <- sub 3, 4 - for i in [0, 1] { - // for pub 1, 2 - let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_ids[i]) - .await; - let _ = pubsub_relation_manager - .set_upstream_announced_namespace( - track_namespaces[i].clone(), - upstream_session_ids[i], - ) - .await; - let (upstream_subscribe_id, _) = pubsub_relation_manager - .set_upstream_subscription( - upstream_session_ids[i], - track_namespaces[i].clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) - .await - .unwrap(); - upstream_subscribe_ids.push(upstream_subscribe_id); - } - - for j in [0, 1, 2] { - // for sub 2, 3, 4 - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_ids[j]) - .await; - } - - // for sub 2 + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; let _ = pubsub_relation_manager .set_downstream_subscription( - downstream_session_ids[0], - downstream_subscribe_ids[0], - downstream_track_aliases[0], - track_namespaces[0].clone(), - track_name.clone(), + downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, ) .await; - for i in [0, 1] { - // for pub 1, 2 - for j in [1, 2] { - // for sub 3, 4 - let _ = pubsub_relation_manager - .set_downstream_subscription( - downstream_session_ids[j], - downstream_subscribe_ids[j], - downstream_track_aliases[j], - track_namespaces[i].clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) - .await; + let result_range = pubsub_relation_manager + .get_downstream_requested_object_range(downstream_session_id, subscribe_id) + .await + .unwrap() + .unwrap(); - let _ = pubsub_relation_manager - .set_pubsub_relation( - upstream_session_ids[i], - upstream_subscribe_ids[i], - downstream_session_ids[j], - downstream_subscribe_ids[j], - ) - .await; - let _ = pubsub_relation_manager - .activate_downstream_subscription( - downstream_session_ids[j], - downstream_subscribe_ids[j], - ) - .await; + assert_eq!(result_range.start_group(), start_group); + assert_eq!(result_range.start_object(), start_object); + assert_eq!(result_range.end_group(), end_group); + } - let _ = pubsub_relation_manager - .activate_upstream_subscription( - upstream_session_ids[i], - upstream_subscribe_ids[i], - ) - .await; - } - } + #[tokio::test] + async fn downstream_actual_object_start() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::LatestObject; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + let actual_object_start = ObjectStart::new(1, 1); - // for pub 1 and sub 2 + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .set_pubsub_relation( - upstream_session_ids[0], - upstream_subscribe_ids[0], - downstream_session_ids[0], - downstream_subscribe_ids[0], - ) + .setup_subscriber(max_subscribe_id, downstream_session_id) .await; let _ = pubsub_relation_manager - .activate_downstream_subscription( - downstream_session_ids[0], - downstream_subscribe_ids[0], + .set_downstream_subscription( + downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, ) .await; - let _ = pubsub_relation_manager - .activate_upstream_subscription(upstream_session_ids[0], upstream_subscribe_ids[0]) - .await; - - // Delete: pub 2, sub 2 - // Remain: pub 1 <- sub 3, 4 let result = pubsub_relation_manager - .delete_client(downstream_session_ids[0]) + .set_downstream_actual_object_start( + downstream_session_id, + subscribe_id, + actual_object_start.clone(), + ) .await; assert!(result.is_ok()); - let delete_occured = result.unwrap(); - assert!(delete_occured); - - let (consumers, producers, pubsub_relation) = + // Assert that the actual start is set + let (_, producers, _) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let producer = producers.get(&downstream_session_id).unwrap(); + let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); - // Assert that sub 2 is deleted - // Remain: sub 3, 4 - let sub2 = producers.get(&downstream_session_ids[0]); - assert!(sub2.is_none()); - - let sub3 = producers.get(&downstream_session_ids[1]); - assert!(sub3.is_some()); - - let sub4 = producers.get(&downstream_session_ids[2]); - assert!(sub4.is_some()); - - // Assert that pub 2 is deleted - // Remain: pub 1 - let pub1 = consumers.get(&upstream_session_ids[1]); - assert!(pub1.is_none()); - - let pub2 = consumers.get(&upstream_session_ids[0]); - assert!(pub2.is_some()); - - // Assert that the relation is deleted - // Remain: pub 1 <- sub 3, 4 - let pub1_relation = - pubsub_relation.get_subscribers(upstream_session_ids[0], upstream_subscribe_ids[0]); - assert!(pub1_relation.is_some()); - - let pub2_relation = - pubsub_relation.get_subscribers(upstream_session_ids[1], upstream_subscribe_ids[1]); - assert!(pub2_relation.is_none()); - } - - #[tokio::test] - async fn delete_client_not_exists() { - let session_id = 1; - - // Start track management thread - let (track_tx, mut track_rx) = mpsc::channel::(1024); - tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); - - let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); - - let result = pubsub_relation_manager.delete_client(session_id).await; - assert!(result.is_ok()); + let result_actual_object_start = subscription.get_actual_object_start().unwrap(); - let delete_occured = result.unwrap(); - assert!(!delete_occured); + assert_eq!(result_actual_object_start, actual_object_start); } #[tokio::test] - async fn delete_pubsub_relation() { + async fn set_upstream_stream_id() { let max_subscribe_id = 10; - let upstream_session_id = 1; - let downstream_session_id = 2; - let downstream_subscribe_id = 3; + let upstream_session_id = 1; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; @@ -2347,7 +3205,9 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; + let group_id = 2; + let subgroup_id = 3; + let stream_id = 4; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2357,75 +3217,53 @@ mod success { let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let _ = pubsub_relation_manager - .setup_subscriber(max_subscribe_id, downstream_session_id) - .await; let _ = pubsub_relation_manager .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) .await; - let _ = pubsub_relation_manager - .set_downstream_subscription( - downstream_session_id, - max_subscribe_id, - 0, - track_namespace.clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) - .await; let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, - track_namespace.clone(), - track_name.clone(), + track_namespace, + track_name, subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, ) .await .unwrap(); - let _ = pubsub_relation_manager - .set_pubsub_relation( - upstream_session_id, - upstream_subscribe_id, - downstream_session_id, - downstream_subscribe_id, - ) - .await; - let result = pubsub_relation_manager - .delete_pubsub_relation( + let _ = pubsub_relation_manager + .set_upstream_stream_id( upstream_session_id, upstream_subscribe_id, - downstream_session_id, - downstream_subscribe_id, + group_id, + subgroup_id, + stream_id, ) .await; - assert!(result.is_ok()); - - let (_, _, pubsub_relation) = + let (consumers, _, _) = test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - - let relation = pubsub_relation - .get_subscribers(upstream_session_id, upstream_subscribe_id) + let consumer = consumers.get(&upstream_session_id).unwrap(); + let subscription = consumer + .get_subscription(upstream_subscribe_id) + .unwrap() .unwrap(); - assert!(relation.is_empty()); + let result_subgroup_id = subscription.get_subgroup_ids_for_group(group_id)[0]; + assert_eq!(result_subgroup_id, subgroup_id); + + let result_stream_id = subscription + .get_stream_id_for_subgroup(group_id, result_subgroup_id) + .unwrap(); + assert_eq!(result_stream_id, stream_id); } #[tokio::test] - async fn delete_upstream_subscription() { + async fn get_upstream_subscribe_ids_for_client() { let max_subscribe_id = 10; let upstream_session_id = 1; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); @@ -2436,7 +3274,6 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2446,37 +3283,42 @@ mod success { let _ = pubsub_relation_manager .setup_publisher(max_subscribe_id, upstream_session_id) .await; - let (upstream_subscribe_id, _) = pubsub_relation_manager - .set_upstream_subscription( - upstream_session_id, - track_namespace.clone(), - track_name.clone(), - subscriber_priority, - group_order, - filter_type, - start_group, - start_object, - end_group, - end_object, - ) + + let mut upstream_subscribe_ids: Vec = vec![]; + for _ in 0..3 { + let _ = pubsub_relation_manager + .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) + .await; + let (upstream_subscribe_id, _) = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + ) + .await + .unwrap(); + + upstream_subscribe_ids.push(upstream_subscribe_id); + } + + let mut result_subscribe_ids = pubsub_relation_manager + .get_upstream_subscribe_ids_for_client(upstream_session_id) .await .unwrap(); - let result = pubsub_relation_manager - .delete_upstream_subscription(upstream_session_id, upstream_subscribe_id) - .await; - assert!(result.is_ok()); - - let (consumers, _, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let consumer = consumers.get(&upstream_session_id).unwrap(); - let subscription = consumer.get_subscription(upstream_subscribe_id).unwrap(); + result_subscribe_ids.sort(); - assert!(subscription.is_none()); + assert_eq!(result_subscribe_ids, upstream_subscribe_ids); } #[tokio::test] - async fn delete_downstream_subscription() { + async fn get_downstream_group_ids_for_subscription() { let max_subscribe_id = 10; let downstream_session_id = 1; let subscribe_id = 0; @@ -2489,7 +3331,9 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; + let group_ids: Vec = vec![2, 3, 4]; + let subgroup_ids: Vec = vec![5, 6, 7]; + let stream_ids: Vec = vec![8, 9, 10]; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2512,25 +3356,31 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; - let result = pubsub_relation_manager - .delete_downstream_subscription(downstream_session_id, subscribe_id) - .await; - assert!(result.is_ok()); + for i in 0..group_ids.len() { + let _ = pubsub_relation_manager + .set_downstream_stream_id( + downstream_session_id, + subscribe_id, + group_ids[i], + subgroup_ids[i], + stream_ids[i], + ) + .await; + } - let (_, producers, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let producer = producers.get(&downstream_session_id).unwrap(); - let subscription = producer.get_subscription(subscribe_id).unwrap(); + let result_group_ids = pubsub_relation_manager + .get_downstream_group_ids_for_subscription(downstream_session_id, subscribe_id) + .await + .unwrap(); - assert!(subscription.is_none()); + assert_eq!(result_group_ids, group_ids); } #[tokio::test] - async fn set_upstream_forwarding_preference() { + async fn get_upstream_stream_ids_from_group() { let max_subscribe_id = 10; let upstream_session_id = 1; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); @@ -2541,8 +3391,9 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; - let forwarding_preference = ForwardingPreference::Track; + let group_id = 2; + let subgroup_ids: Vec = vec![3, 4, 5]; + let stream_ids: Vec = vec![6, 7, 8]; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2558,46 +3409,63 @@ mod success { let (upstream_subscribe_id, _) = pubsub_relation_manager .set_upstream_subscription( upstream_session_id, - track_namespace.clone(), - track_name.clone(), + track_namespace, + track_name, subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, ) .await .unwrap(); - let result = pubsub_relation_manager - .set_upstream_forwarding_preference( + for i in 0..subgroup_ids.len() { + let _ = pubsub_relation_manager + .set_upstream_stream_id( + upstream_session_id, + upstream_subscribe_id, + group_id, + subgroup_ids[i], + stream_ids[i], + ) + .await; + } + + let result_subgroup_ids = pubsub_relation_manager + .get_upstream_subgroup_ids_for_group( upstream_session_id, upstream_subscribe_id, - forwarding_preference.clone(), + group_id, ) - .await; - assert!(result.is_ok()); - - // Assert that the forwarding preference is set - let (consumers, _, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let consumer = consumers.get(&upstream_session_id).unwrap(); - let subscription = consumer - .get_subscription(upstream_subscribe_id) - .unwrap() + .await .unwrap(); - let result_forwarding_preference = subscription.get_forwarding_preference().unwrap(); + assert_eq!(result_subgroup_ids, subgroup_ids); - assert_eq!(result_forwarding_preference, forwarding_preference); + for i in 0..subgroup_ids.len() { + let result_stream_id = pubsub_relation_manager + .get_upstream_stream_id_for_subgroup( + upstream_session_id, + upstream_subscribe_id, + group_id, + result_subgroup_ids[i], + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(result_stream_id, stream_ids[i]); + } } #[tokio::test] - async fn get_upstream_forwarding_preference() { + async fn set_downstream_stream_id() { let max_subscribe_id = 10; - let upstream_session_id = 1; + let downstream_session_id = 1; + let subscribe_id = 0; + let track_alias = 0; let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); let track_name = "track_name".to_string(); let subscriber_priority = 0; @@ -2606,8 +3474,9 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; - let forwarding_preference = ForwardingPreference::Track; + let group_id = 2; + let subgroup_id = 3; + let stream_id = 4; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2615,46 +3484,93 @@ mod success { let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); let _ = pubsub_relation_manager - .setup_publisher(max_subscribe_id, upstream_session_id) + .setup_subscriber(max_subscribe_id, downstream_session_id) .await; let _ = pubsub_relation_manager - .set_upstream_announced_namespace(track_namespace.clone(), upstream_session_id) - .await; - let (upstream_subscribe_id, _) = pubsub_relation_manager - .set_upstream_subscription( - upstream_session_id, - track_namespace.clone(), - track_name.clone(), + .set_downstream_subscription( + downstream_session_id, + subscribe_id, + track_alias, + track_namespace, + track_name, subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, ) - .await - .unwrap(); + .await; + let _ = pubsub_relation_manager - .set_upstream_forwarding_preference( - upstream_session_id, - upstream_subscribe_id, - forwarding_preference.clone(), + .set_downstream_stream_id( + downstream_session_id, + subscribe_id, + group_id, + subgroup_id, + stream_id, ) .await; - let result = pubsub_relation_manager - .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) - .await; - assert!(result.is_ok()); + let (_, producers, _) = + test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; + let producer = producers.get(&downstream_session_id).unwrap(); + let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); - let result_forwarding_preference = result.unwrap().unwrap(); + let result_subgroup_id = subscription.get_subgroup_ids_for_group(group_id)[0]; + assert_eq!(result_subgroup_id, subgroup_id); - assert_eq!(result_forwarding_preference, forwarding_preference); + let result_stream_id = subscription + .get_stream_id_for_subgroup(group_id, result_subgroup_id) + .unwrap(); + assert_eq!(result_stream_id, stream_id); } #[tokio::test] - async fn set_downstream_forwarding_preference() { + async fn get_downstream_subscribe_ids_for_client() { + let max_subscribe_id = 10; + let downstream_session_id = 1; + let subscribe_ids: Vec = vec![2, 3, 4]; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + for subscribe_id in subscribe_ids.iter() { + let _ = pubsub_relation_manager + .setup_subscriber(max_subscribe_id, downstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_downstream_subscription( + downstream_session_id, + *subscribe_id, + 0, + Vec::from(["test".to_string(), "test".to_string()]), + "track_name".to_string(), + 0, + GroupOrder::Ascending, + FilterType::AbsoluteStart, + Some(0), + Some(0), + None, + ) + .await; + } + + let mut result_subscribe_ids = pubsub_relation_manager + .get_downstream_subscribe_ids_for_client(downstream_session_id) + .await + .unwrap(); + + result_subscribe_ids.sort(); + + assert_eq!(result_subscribe_ids, subscribe_ids); + } + + #[tokio::test] + async fn get_downstream_stream_ids_from_group() { let max_subscribe_id = 10; let downstream_session_id = 1; let subscribe_id = 0; @@ -2667,8 +3583,9 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; - let forwarding_preference = ForwardingPreference::Subgroup; + let group_id = 2; + let subgroup_ids: Vec = vec![3, 4, 5]; + let stream_ids: Vec = vec![6, 7, 8]; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2683,36 +3600,49 @@ mod success { downstream_session_id, subscribe_id, track_alias, - track_namespace.clone(), - track_name.clone(), + track_namespace, + track_name, subscriber_priority, group_order, filter_type, start_group, start_object, end_group, - end_object, - ) - .await; - - let result = pubsub_relation_manager - .set_downstream_forwarding_preference( - downstream_session_id, - subscribe_id, - forwarding_preference.clone(), ) .await; - assert!(result.is_ok()); - // Assert that the forwarding preference is set - let (_, producers, _) = - test_helper_fn::get_node_and_relation_clone(&pubsub_relation_manager).await; - let producer = producers.get(&downstream_session_id).unwrap(); - let subscription = producer.get_subscription(subscribe_id).unwrap().unwrap(); + for i in 0..stream_ids.len() { + let _ = pubsub_relation_manager + .set_downstream_stream_id( + downstream_session_id, + subscribe_id, + group_id, + subgroup_ids[i], + stream_ids[i], + ) + .await; + } - let result_forwarding_preference = subscription.get_forwarding_preference().unwrap(); + let result_subgroup_ids = pubsub_relation_manager + .get_downstream_subgroup_ids_for_group(downstream_session_id, subscribe_id, group_id) + .await + .unwrap(); + assert_eq!(result_subgroup_ids, subgroup_ids); + + for i in 0..subgroup_ids.len() { + let result_stream_id = pubsub_relation_manager + .get_downstream_stream_id_for_subgroup( + downstream_session_id, + subscribe_id, + group_id, + result_subgroup_ids[i], + ) + .await + .unwrap() + .unwrap(); - assert_eq!(result_forwarding_preference, forwarding_preference); + assert_eq!(result_stream_id, stream_ids[i]); + } } #[tokio::test] @@ -2730,7 +3660,6 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2756,7 +3685,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -2778,7 +3706,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -2830,7 +3757,6 @@ mod success { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; // Start track management thread let (track_tx, mut track_rx) = mpsc::channel::(1024); @@ -2856,7 +3782,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await .unwrap(); @@ -2878,7 +3803,6 @@ mod success { start_group, start_object, end_group, - end_object, ) .await; let _ = pubsub_relation_manager @@ -2917,7 +3841,7 @@ mod failure { commands::PubSubRelationCommand, manager::pubsub_relation_manager, wrapper::PubSubRelationManagerWrapper, }; - use moqt_core::messages::control_messages::subscribe::{FilterType, GroupOrder}; + use moqt_core::messages::control_messages::{group_order::GroupOrder, subscribe::FilterType}; use moqt_core::pubsub_relation_manager_repository::PubSubRelationManagerRepository; use tokio::sync::mpsc; @@ -3105,6 +4029,24 @@ mod failure { assert!(result.is_err()); } + #[tokio::test] + async fn get_upstream_subscribe_id_by_track_alias_publisher_not_found() { + let track_alias = 0; + let invalid_upstream_session_id = 1; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + + let result = pubsub_relation_manager + .get_upstream_subscribe_id_by_track_alias(invalid_upstream_session_id, track_alias) + .await; + + assert!(result.is_err()); + } + #[tokio::test] async fn set_downstream_subscription_subscriber_not_found() { let max_subscribe_id = 10; @@ -3119,7 +4061,6 @@ mod failure { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; let invalid_downstream_session_id = 2; // Start track management thread @@ -3145,7 +4086,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; @@ -3164,7 +4104,6 @@ mod failure { let start_group = Some(0); let start_object = Some(0); let end_group = None; - let end_object = None; let invalid_upstream_session_id = 2; // Start track management thread @@ -3188,7 +4127,6 @@ mod failure { start_group, start_object, end_group, - end_object, ) .await; diff --git a/moqt-server/src/modules/send_stream_dispatcher.rs b/moqt-server/src/modules/send_stream_dispatcher.rs deleted file mode 100644 index fd3001ea..00000000 --- a/moqt-server/src/modules/send_stream_dispatcher.rs +++ /dev/null @@ -1,157 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use moqt_core::{ - constants::StreamDirection, messages::moqt_payload::MOQTPayload, SendStreamDispatcherRepository, -}; -use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{mpsc, oneshot}; -type SenderToSendStreamThread = mpsc::Sender>>; - -#[derive(Debug)] -pub(crate) enum SendStreamDispatchCommand { - Set { - session_id: usize, - stream_direction: StreamDirection, - sender: SenderToSendStreamThread, - }, - List { - stream_direction: StreamDirection, - exclude_session_id: Option, // Currently, exclude_session_id is only used in broadcast for List - resp: oneshot::Sender>, - }, - Get { - session_id: usize, - stream_direction: StreamDirection, - resp: oneshot::Sender>, - }, - Delete { - session_id: usize, - }, -} - -pub(crate) async fn send_stream_dispatcher(rx: &mut mpsc::Receiver) { - tracing::trace!("send_stream_dispatcher start"); - // { - // "${session_id}" : { - // "StreamDirection::Uni" : tx, - // "StreamDirection::Bi" : tx, - // } - // } - let mut dispatcher = - HashMap::>::new(); - - while let Some(cmd) = rx.recv().await { - tracing::debug!("command received: {:#?}", cmd); - match cmd { - SendStreamDispatchCommand::Set { - session_id, - stream_direction, - sender, - } => { - let inner_map = dispatcher.entry(session_id).or_default(); - inner_map.insert(stream_direction, sender); - tracing::debug!("set: {:?} of {:?}", stream_direction, session_id); - } - SendStreamDispatchCommand::List { - stream_direction, - exclude_session_id, - resp, - } => { - let mut senders = Vec::new(); - for (session_id, inner_map) in &dispatcher { - if let Some(exclude_session_id) = exclude_session_id { - if *session_id == exclude_session_id { - continue; - } - } - if let Some(sender) = inner_map.get(&stream_direction) { - senders.push(sender.clone()); - } - } - let _ = resp.send(senders); - } - SendStreamDispatchCommand::Get { - session_id, - stream_direction, - resp, - } => { - let sender = dispatcher - .get(&session_id) - .and_then(|inner_map| inner_map.get(&stream_direction)) - .cloned(); - tracing::debug!("get: {:?}", sender); - let _ = resp.send(sender); - } - SendStreamDispatchCommand::Delete { session_id } => { - dispatcher.remove(&session_id); - tracing::debug!("delete: {:?}", session_id); - } - } - } - - tracing::trace!("send_stream_dispatcher end"); -} - -#[derive(Clone)] -pub(crate) struct SendStreamDispatcher { - tx: mpsc::Sender, -} - -impl SendStreamDispatcher { - pub fn new(tx: mpsc::Sender) -> Self { - Self { tx } - } - - // Used for testing in unsubscribe_handler - #[allow(dead_code)] - pub fn get_tx(&self) -> mpsc::Sender { - self.tx.clone() - } -} - -#[async_trait] -impl SendStreamDispatcherRepository for SendStreamDispatcher { - async fn broadcast_message_to_send_stream_threads( - &self, - session_id: Option, - message: Box, - ) -> Result<()> { - let (resp_tx, resp_rx) = oneshot::channel::>(); - let cmd = SendStreamDispatchCommand::List { - stream_direction: StreamDirection::Bi, - exclude_session_id: session_id, - resp: resp_tx, - }; - self.tx.send(cmd).await.unwrap(); - - let senders = resp_rx.await?; - let message_arc = Arc::new(message); - for sender in senders { - let message_arc_clone = Arc::clone(&message_arc); - let _ = sender.send(message_arc_clone).await; - } - Ok(()) - } - async fn transfer_message_to_send_stream_thread( - &self, - session_id: usize, - message: Box, - stream_direction: StreamDirection, - ) -> Result<()> { - let (resp_tx, resp_rx) = oneshot::channel::>(); - - let cmd = SendStreamDispatchCommand::Get { - session_id, - stream_direction, - resp: resp_tx, - }; - self.tx.send(cmd).await.unwrap(); - - let sender = resp_rx - .await? - .ok_or_else(|| anyhow::anyhow!("sender not found"))?; - let message_arc = Arc::new(message); - let _ = sender.send(message_arc).await; - Ok(()) - } -} diff --git a/moqt-server/src/modules/server_processes/control_stream/handler.rs b/moqt-server/src/modules/server_processes/control_stream/handler.rs index 34b6ee0b..ec7348b6 100644 --- a/moqt-server/src/modules/server_processes/control_stream/handler.rs +++ b/moqt-server/src/modules/server_processes/control_stream/handler.rs @@ -1,10 +1,10 @@ use super::bi_stream::BiStream; use crate::modules::{ buffer_manager::{request_buffer, BufferCommand}, + control_message_dispatcher::ControlMessageDispatcher, message_handlers::control_message::{control_message_handler, MessageProcessResult}, moqt_client::MOQTClient, pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, - send_stream_dispatcher::SendStreamDispatcher, }; use anyhow::Result; use bytes::BytesMut; @@ -27,7 +27,8 @@ pub(crate) async fn handle_control_stream( let mut pubsub_relation_manager = PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); - let mut send_stream_dispatcher = SendStreamDispatcher::new(senders.send_stream_tx().clone()); + let mut control_message_dispatcher = + ControlMessageDispatcher::new(senders.control_message_dispatch_tx().clone()); let mut object_cache_storage = crate::modules::object_cache_storage::wrapper::ObjectCacheStorageWrapper::new( @@ -57,7 +58,7 @@ pub(crate) async fn handle_control_stream( &mut client, senders.start_forwarder_txes().clone(), &mut pubsub_relation_manager, - &mut send_stream_dispatcher, + &mut control_message_dispatcher, &mut object_cache_storage, ) .await; diff --git a/moqt-server/src/modules/server_processes/control_stream/sender.rs b/moqt-server/src/modules/server_processes/control_stream/sender.rs index 5c45ce88..d2c607bb 100644 --- a/moqt-server/src/modules/server_processes/control_stream/sender.rs +++ b/moqt-server/src/modules/server_processes/control_stream/sender.rs @@ -4,9 +4,8 @@ use moqt_core::{ messages::{ control_messages::{ announce::Announce, announce_ok::AnnounceOk, subscribe::Subscribe, - subscribe_error::SubscribeError, subscribe_namespace::SubscribeNamespace, - subscribe_namespace_ok::SubscribeNamespaceOk, subscribe_ok::SubscribeOk, - unsubscribe::Unsubscribe, + subscribe_announces::SubscribeAnnounces, subscribe_announces_ok::SubscribeAnnouncesOk, + subscribe_error::SubscribeError, subscribe_ok::SubscribeOk, unsubscribe::Unsubscribe, }, moqt_payload::MOQTPayload, }, @@ -59,27 +58,27 @@ pub(crate) async fn send_control_stream( tracing::info!("Relayed Message Type: {:?}", ControlMessageType::AnnounceOk); } else if message .as_any() - .downcast_ref::() + .downcast_ref::() .is_some() { message_buf.extend(write_variable_integer(u8::from( - ControlMessageType::SubscribeNamespace, + ControlMessageType::SubscribeAnnounces, ) as u64)); tracing::info!( "Relayed Message Type: {:?}", - ControlMessageType::SubscribeNamespace + ControlMessageType::SubscribeAnnounces ); } else if message .as_any() - .downcast_ref::() + .downcast_ref::() .is_some() { message_buf.extend(write_variable_integer(u8::from( - ControlMessageType::SubscribeNamespaceOk, + ControlMessageType::SubscribeAnnouncesOk, ) as u64)); tracing::info!( "Relayed Message Type: {:?}", - ControlMessageType::SubscribeNamespaceOk + ControlMessageType::SubscribeAnnouncesOk ); } else if message.as_any().downcast_ref::().is_some() { message_buf.extend(write_variable_integer( diff --git a/moqt-server/src/modules/server_processes/data_streams.rs b/moqt-server/src/modules/server_processes/data_streams.rs index 7ba1af2c..f710c66d 100644 --- a/moqt-server/src/modules/server_processes/data_streams.rs +++ b/moqt-server/src/modules/server_processes/data_streams.rs @@ -1,2 +1,2 @@ pub(crate) mod datagram; -pub(crate) mod stream; +pub(crate) mod subgroup_stream; diff --git a/moqt-server/src/modules/server_processes/data_streams/datagram/forwarder.rs b/moqt-server/src/modules/server_processes/data_streams/datagram/forwarder.rs index 2cef47a3..3689f75a 100644 --- a/moqt-server/src/modules/server_processes/data_streams/datagram/forwarder.rs +++ b/moqt-server/src/modules/server_processes/data_streams/datagram/forwarder.rs @@ -11,9 +11,11 @@ use moqt_core::{ data_stream_type::DataStreamType, messages::{ control_messages::subscribe::FilterType, - data_streams::{datagram, object_status::ObjectStatus, DataStreams}, + data_streams::{ + datagram, datagram_status, object_status::ObjectStatus, DataStreams, DatagramObject, + }, }, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, + models::{range::ObjectRange, tracks::ForwardingPreference}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, variable_integer::write_variable_integer, }; @@ -26,8 +28,10 @@ pub(crate) struct DatagramObjectForwarder { session: Arc, senders: Arc, downstream_subscribe_id: u64, - downstream_subscription: Subscription, + downstream_track_alias: u64, cache_key: CacheKey, + filter_type: FilterType, + requested_object_range: ObjectRange, sleep_time: Duration, } @@ -44,8 +48,18 @@ impl DatagramObjectForwarder { let downstream_session_id = session.stable_id(); - let downstream_subscription = pubsub_relation_manager - .get_downstream_subscription_by_ids(downstream_session_id, downstream_subscribe_id) + let downstream_track_alias = pubsub_relation_manager + .get_downstream_track_alias(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + let filter_type = pubsub_relation_manager + .get_downstream_filter_type(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + let requested_object_range = pubsub_relation_manager + .get_downstream_requested_object_range(downstream_session_id, downstream_subscribe_id) .await? .unwrap(); @@ -60,8 +74,10 @@ impl DatagramObjectForwarder { session, senders, downstream_subscribe_id, - downstream_subscription, + downstream_track_alias, cache_key, + filter_type, + requested_object_range, sleep_time, }; @@ -171,31 +187,37 @@ impl DatagramObjectForwarder { ) -> Result<(Option, bool)> { // Do loop until get an object from the cache storage loop { - let (cache_id, datagram_object) = - match self.try_get_object(object_cache_storage, cache_id).await? { - Some((id, object)) => (id, object), - None => { - // If there is no object in the cache storage, sleep for a while and try again - thread::sleep(self.sleep_time); - continue; - } - }; - - let message_buf = self.packetize(&datagram_object).await?; + let (cache_id, upstream_object) = match self + .try_get_upstream_object(object_cache_storage, cache_id) + .await? + { + Some((id, object)) => (id, object), + None => { + // If there is no object in the cache storage, sleep for a while and try again + thread::sleep(self.sleep_time); + continue; + } + }; + + let downstream_object = self.generate_downstream_object(&upstream_object); + + let message_buf = self.packetize(&downstream_object).await?; self.send(message_buf).await?; - let is_end = self.is_subscription_ended(&datagram_object) - || self.is_data_stream_ended(&datagram_object); + let mut is_end = false; + if let DatagramObject::ObjectDatagramStatus(object) = &downstream_object { + is_end = self.is_data_stream_ended(object); + } return Ok((Some(cache_id), is_end)); } } - async fn try_get_object( + async fn try_get_upstream_object( &self, object_cache_storage: &mut ObjectCacheStorageWrapper, cache_id: Option, - ) -> Result> { + ) -> Result> { let cache = match cache_id { // Try to get the first object according to Filter Type None => self.try_get_first_object(object_cache_storage).await?, @@ -215,10 +237,9 @@ impl DatagramObjectForwarder { async fn try_get_first_object( &self, object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result> { - let filter_type = self.downstream_subscription.get_filter_type(); - - match filter_type { + ) -> Result> { + match self.filter_type { + // TODO: Remove LatestGroup since it is not exist in the draft-10 FilterType::LatestGroup => { object_cache_storage .get_latest_datagram_group(&self.cache_key) @@ -230,14 +251,11 @@ impl DatagramObjectForwarder { .await } FilterType::AbsoluteStart | FilterType::AbsoluteRange => { - let (start_group, start_object) = self.downstream_subscription.get_absolute_start(); + let start_group = self.requested_object_range.start_group().unwrap(); + let start_object = self.requested_object_range.start_object().unwrap(); object_cache_storage - .get_absolute_datagram_object( - &self.cache_key, - start_group.unwrap(), - start_object.unwrap(), - ) + .get_absolute_datagram_object(&self.cache_key, start_group, start_object) .await } } @@ -247,13 +265,71 @@ impl DatagramObjectForwarder { &self, object_cache_storage: &mut ObjectCacheStorageWrapper, object_cache_id: usize, - ) -> Result> { + ) -> Result> { object_cache_storage .get_next_datagram_object(&self.cache_key, object_cache_id) .await } - async fn packetize(&mut self, datagram_object: &datagram::Object) -> Result { + fn generate_downstream_object(&self, upstream_object: &DatagramObject) -> DatagramObject { + match upstream_object { + DatagramObject::ObjectDatagram(object) => { + let object_datagram = self.generate_downstream_object_datagram(object); + DatagramObject::ObjectDatagram(object_datagram) + } + DatagramObject::ObjectDatagramStatus(object) => { + let object_datagram_status = + self.generate_downstream_object_datagram_status(object); + DatagramObject::ObjectDatagramStatus(object_datagram_status) + } + } + } + + fn generate_downstream_object_datagram( + &self, + upstream_object: &datagram::Object, + ) -> datagram::Object { + let extension_headers = upstream_object.extension_headers().clone(); + datagram::Object::new( + self.downstream_track_alias, // Replace with downstream_track_alias + upstream_object.group_id(), + upstream_object.object_id(), + upstream_object.publisher_priority(), + extension_headers, + upstream_object.object_payload(), + ) + .unwrap() + } + + fn generate_downstream_object_datagram_status( + &self, + upstream_object: &datagram_status::Object, + ) -> datagram_status::Object { + let extension_headers = upstream_object.extension_headers().clone(); + datagram_status::Object::new( + self.downstream_track_alias, // Replace with downstream_track_alias + upstream_object.group_id(), + upstream_object.object_id(), + upstream_object.publisher_priority(), + extension_headers, + upstream_object.object_status(), + ) + .unwrap() + } + + async fn packetize(&mut self, downstream_object: &DatagramObject) -> Result { + match downstream_object { + DatagramObject::ObjectDatagram(object) => self.packetize_object_datagram(object).await, + DatagramObject::ObjectDatagramStatus(object) => { + self.packetize_object_datagram_status(object).await + } + } + } + + async fn packetize_object_datagram( + &mut self, + datagram_object: &datagram::Object, + ) -> Result { let mut buf = BytesMut::new(); datagram_object.packetize(&mut buf); @@ -266,6 +342,22 @@ impl DatagramObjectForwarder { Ok(message_buf) } + async fn packetize_object_datagram_status( + &mut self, + datagram_object: &datagram_status::Object, + ) -> Result { + let mut buf = BytesMut::new(); + datagram_object.packetize(&mut buf); + + let mut message_buf = BytesMut::with_capacity(buf.len()); + message_buf.extend(write_variable_integer( + u8::from(DataStreamType::ObjectDatagramStatus) as u64, + )); + message_buf.extend(buf); + + Ok(message_buf) + } + async fn send(&mut self, message_buf: BytesMut) -> Result<()> { if let Err(e) = self.session.send_datagram(&message_buf) { tracing::warn!("Failed to send datagram: {:?}", e); @@ -275,21 +367,14 @@ impl DatagramObjectForwarder { Ok(()) } - fn is_subscription_ended(&self, datagram_object: &datagram::Object) -> bool { - let group_id = datagram_object.group_id(); - let object_id = datagram_object.object_id(); - - self.downstream_subscription.is_end(group_id, object_id) - } - // This function is implemented according to the following sentence in draft. - // A relay MAY treat receipt of EndOfGroup, EndOfSubgroup, GroupDoesNotExist, or + // A relay MAY treat receipt of EndOfGroup, EndOfTrack, GroupDoesNotExist, or // EndOfTrack objects as a signal to close corresponding streams even if the FIN // has not arrived, as further objects on the stream would be a protocol violation. - fn is_data_stream_ended(&self, datagram_object: &datagram::Object) -> bool { + fn is_data_stream_ended(&self, object: &datagram_status::Object) -> bool { matches!( - datagram_object.object_status(), - Some(ObjectStatus::EndOfTrackAndGroup) + object.object_status(), + ObjectStatus::EndOfTrack | ObjectStatus::EndOfTrackAndGroup ) } } diff --git a/moqt-server/src/modules/server_processes/data_streams/datagram/receiver.rs b/moqt-server/src/modules/server_processes/data_streams/datagram/receiver.rs index dc5052a7..9a80d162 100644 --- a/moqt-server/src/modules/server_processes/data_streams/datagram/receiver.rs +++ b/moqt-server/src/modules/server_processes/data_streams/datagram/receiver.rs @@ -13,7 +13,7 @@ use anyhow::Result; use bytes::BytesMut; use moqt_core::{ constants::TerminationErrorCode, data_stream_type::DataStreamType, - messages::data_streams::datagram, models::tracks::ForwardingPreference, + messages::data_streams::DatagramObject, models::tracks::ForwardingPreference, pubsub_relation_manager_repository::PubSubRelationManagerRepository, }; use std::sync::Arc; @@ -72,7 +72,11 @@ impl DatagramObjectReceiver { }; let session_id = self.client.lock().await.id(); - let subscribe_id = object.subscribe_id(); + let track_alias = match &object { + DatagramObject::ObjectDatagram(object) => object.track_alias(), + DatagramObject::ObjectDatagramStatus(object) => object.track_alias(), + }; + let subscribe_id = self.get_subscribe_id(session_id, track_alias).await?; if self .is_first_object(session_id, subscribe_id, object_cache_storage) @@ -101,7 +105,7 @@ impl DatagramObjectReceiver { buf.extend_from_slice(&read_bytes); } - async fn read_object_from_buf(&self) -> Result, TerminationError> { + async fn read_object_from_buf(&self) -> Result, TerminationError> { let result = self.try_read_object_from_buf().await; match result { @@ -122,6 +126,33 @@ impl DatagramObjectReceiver { datagram_object::try_read_object(&mut buf, client).await } + async fn get_subscribe_id( + &self, + session_id: usize, + track_alias: u64, + ) -> Result { + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + match pubsub_relation_manager + .get_upstream_subscribe_id_by_track_alias(session_id, track_alias) + .await + { + Ok(Some(subscribe_id)) => Ok(subscribe_id), + Ok(None) => { + let msg = "Subscribe id is not found".to_string(); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + Err(err) => { + let msg = format!("Fail to get subscribe id: {:?}", err); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } + async fn is_first_object( &self, upstream_session_id: usize, @@ -187,7 +218,7 @@ impl DatagramObjectReceiver { async fn store_object( &self, - datagram_object: datagram::Object, + datagram_object: DatagramObject, upstream_session_id: usize, upstream_subscribe_id: u64, object_cache_storage: &mut ObjectCacheStorageWrapper, diff --git a/moqt-server/src/modules/server_processes/data_streams/stream/forwarder.rs b/moqt-server/src/modules/server_processes/data_streams/stream/forwarder.rs deleted file mode 100644 index 7dba34f5..00000000 --- a/moqt-server/src/modules/server_processes/data_streams/stream/forwarder.rs +++ /dev/null @@ -1,577 +0,0 @@ -use super::uni_stream::UniSendStream; -use crate::{ - modules::{ - buffer_manager::BufferCommand, - message_handlers::{stream_header::StreamHeader, stream_object::StreamObject}, - moqt_client::MOQTClient, - object_cache_storage::{cache::CacheKey, wrapper::ObjectCacheStorageWrapper}, - pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, - server_processes::senders::Senders, - }, - SubgroupStreamId, -}; -use anyhow::{bail, Result}; -use bytes::BytesMut; -use moqt_core::{ - data_stream_type::DataStreamType, - messages::{ - control_messages::subscribe::FilterType, - data_streams::{object_status::ObjectStatus, subgroup_stream, track_stream, DataStreams}, - }, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, - pubsub_relation_manager_repository::PubSubRelationManagerRepository, - variable_integer::write_variable_integer, -}; -use std::{sync::Arc, thread, time::Duration}; -use tokio::sync::Mutex; -use tracing::{self}; - -pub(crate) struct StreamObjectForwarder { - stream: UniSendStream, - senders: Arc, - downstream_subscribe_id: u64, - downstream_subscription: Subscription, - data_stream_type: DataStreamType, - cache_key: CacheKey, - subgroup_stream_id: Option, - sleep_time: Duration, -} - -impl StreamObjectForwarder { - pub(crate) async fn init( - stream: UniSendStream, - downstream_subscribe_id: u64, - client: Arc>, - data_stream_type: DataStreamType, - subgroup_stream_id: Option, - ) -> Result { - let senders = client.lock().await.senders(); - let sleep_time = Duration::from_millis(10); - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); - - let downstream_session_id = stream.stable_id(); - - let downstream_subscription = pubsub_relation_manager - .get_downstream_subscription_by_ids(downstream_session_id, downstream_subscribe_id) - .await? - .unwrap(); - - // Get the information of the original publisher who has the track being requested - let (upstream_session_id, upstream_subscribe_id) = pubsub_relation_manager - .get_related_publisher(downstream_session_id, downstream_subscribe_id) - .await?; - - let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); - - let stream_object_forwarder = StreamObjectForwarder { - stream, - senders, - downstream_subscribe_id, - downstream_subscription, - data_stream_type, - cache_key, - subgroup_stream_id, - sleep_time, - }; - - Ok(stream_object_forwarder) - } - - pub(crate) async fn start(&mut self) -> Result<()> { - let mut object_cache_storage = - ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); - - let upstream_forwarding_preference = self.get_upstream_forwarding_preference().await?; - self.validate_forwarding_preference(&upstream_forwarding_preference) - .await?; - - let downstream_forwarding_preference = upstream_forwarding_preference.clone(); - self.set_forwarding_preference(downstream_forwarding_preference) - .await?; - - self.forward_header(&mut object_cache_storage).await?; - - self.forward_objects(&mut object_cache_storage).await?; - - Ok(()) - } - - pub(crate) async fn finish(&self) -> Result<()> { - let downstream_session_id = self.stream.stable_id(); - let downstream_stream_id = self.stream.stream_id(); - self.senders - .buffer_tx() - .send(BufferCommand::ReleaseStream { - session_id: downstream_session_id, - stream_id: downstream_stream_id, - }) - .await?; - - tracing::info!("StreamObjectForwarder finished"); - - Ok(()) - } - - async fn get_upstream_forwarding_preference(&self) -> Result> { - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); - - let upstream_session_id = self.cache_key.session_id(); - let upstream_subscribe_id = self.cache_key.subscribe_id(); - - pubsub_relation_manager - .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) - .await - } - - async fn validate_forwarding_preference( - &self, - upstream_forwarding_preference: &Option, - ) -> Result<()> { - match upstream_forwarding_preference { - Some(ForwardingPreference::Track) => self.check_data_stream_type_track().await?, - Some(ForwardingPreference::Subgroup) => self.check_data_stream_type_subgroup().await?, - _ => { - bail!("Forwarding preference is not Stream"); - } - } - - Ok(()) - } - - async fn check_data_stream_type_track(&self) -> Result<()> { - if self.data_stream_type != DataStreamType::StreamHeaderTrack { - bail!( - "uni send stream's data stream type is wrong (expected Track, but got {:?})", - self.data_stream_type - ); - } - - Ok(()) - } - - async fn check_data_stream_type_subgroup(&self) -> Result<()> { - if self.data_stream_type != DataStreamType::StreamHeaderSubgroup { - bail!( - "uni send stream's data stream type is wrong (expected Subgroup, but got {:?})", - self.data_stream_type - ); - } - - Ok(()) - } - - async fn set_forwarding_preference( - &self, - downstream_forwarding_preference: Option, - ) -> Result<()> { - let forwarding_preference = downstream_forwarding_preference.unwrap(); - let downstream_session_id = self.stream.stable_id(); - let downstream_subscribe_id = self.downstream_subscribe_id; - - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); - - pubsub_relation_manager - .set_downstream_forwarding_preference( - downstream_session_id, - downstream_subscribe_id, - forwarding_preference, - ) - .await?; - - Ok(()) - } - - async fn forward_header( - &mut self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<()> { - let stream_header = self.get_header(object_cache_storage).await?; - - let message_buf = self.packetize_header(&stream_header).await?; - - self.send(message_buf).await?; - - Ok(()) - } - - async fn get_header( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result { - match self.data_stream_type { - DataStreamType::StreamHeaderTrack => { - let track_stream_header = object_cache_storage - .get_track_stream_header(&self.cache_key) - .await?; - - let header = StreamHeader::Track(track_stream_header); - - Ok(header) - } - DataStreamType::StreamHeaderSubgroup => { - let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); - let subgroup_stream_header = object_cache_storage - .get_subgroup_stream_header(&self.cache_key, group_id, subgroup_id) - .await?; - - let header = StreamHeader::Subgroup(subgroup_stream_header); - - Ok(header) - } - _ => { - let msg = "data stream type is not StreamHeader"; - bail!(msg) - } - } - } - - async fn forward_objects( - &mut self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<()> { - let mut object_cache_id = None; - let mut is_end = false; - - while !is_end { - (object_cache_id, is_end) = self - .forward_object(object_cache_storage, object_cache_id) - .await?; - } - - Ok(()) - } - - async fn forward_object( - &mut self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - cache_id: Option, - ) -> Result<(Option, bool)> { - // Do loop until get an object from the cache storage - loop { - let (cache_id, stream_object) = - match self.try_get_object(object_cache_storage, cache_id).await? { - Some((id, object)) => (id, object), - None => { - // If there is no object in the cache storage, sleep for a while and try again - thread::sleep(self.sleep_time); - continue; - } - }; - - let message_buf = self.packetize_object(&stream_object).await?; - self.send(message_buf).await?; - - let is_end = self.is_subscription_ended(&stream_object) - || self.is_data_stream_ended(&stream_object); - - return Ok((Some(cache_id), is_end)); - } - } - - async fn try_get_object( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - cache_id: Option, - ) -> Result> { - match cache_id { - // Try to get the first object according to Filter Type - None => self.try_get_first_object(object_cache_storage).await, - Some(cache_id) => { - // Try to get the subsequent object with cache_id - self.try_get_subsequent_object(object_cache_storage, cache_id) - .await - } - } - } - - async fn try_get_first_object( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result> { - let stream_object_with_cache_id = match self.data_stream_type { - DataStreamType::StreamHeaderTrack => { - let track_stream_object_with_cache_id = self - .try_get_first_track_stream_object(object_cache_storage) - .await?; - - if track_stream_object_with_cache_id.is_none() { - None - } else { - let (cache_id, object) = track_stream_object_with_cache_id.unwrap(); - let stream_object = StreamObject::Track(object); - - Some((cache_id, stream_object)) - } - } - DataStreamType::StreamHeaderSubgroup => { - let subgroup_stream_object_with_cache_id = self - .try_get_first_subgroup_stream_object(object_cache_storage) - .await?; - - if subgroup_stream_object_with_cache_id.is_none() { - None - } else { - let (cache_id, object) = subgroup_stream_object_with_cache_id.unwrap(); - let stream_object = StreamObject::Subgroup(object); - - Some((cache_id, stream_object)) - } - } - _ => { - let msg = "data stream type is not StreamHeader"; - bail!(msg) - } - }; - - Ok(stream_object_with_cache_id) - } - - async fn try_get_first_track_stream_object( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result> { - let filter_type = self.downstream_subscription.get_filter_type(); - - match filter_type { - FilterType::LatestGroup => { - object_cache_storage - .get_latest_track_stream_group(&self.cache_key) - .await - } - FilterType::LatestObject => { - object_cache_storage - .get_latest_track_stream_object(&self.cache_key) - .await - } - FilterType::AbsoluteStart | FilterType::AbsoluteRange => { - let (start_group, start_object) = self.downstream_subscription.get_absolute_start(); - - object_cache_storage - .get_absolute_track_stream_object( - &self.cache_key, - start_group.unwrap(), - start_object.unwrap(), - ) - .await - } - } - } - - async fn try_get_first_subgroup_stream_object( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result> { - let filter_type = self.downstream_subscription.get_filter_type(); - let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); - - match filter_type { - FilterType::LatestGroup => { - // Try to obtain the first object in the subgroup stream specified by the arguments. - // This operation is the same on the first stream and on subsequent streams. - object_cache_storage - .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) - .await - } - // Currently not supported - FilterType::LatestObject => { - // Try to obtain the first object in the subgroup stream specified by the arguments. - // TODO: If it's on the first subgroup stream, it should get the latest object. - // To distinguish the first stream, we need to modify downstream subscription. - // (e.g. Implementation of FilterType::AbsoluteStart | FilterType::AbsoluteRange) - object_cache_storage - .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) - .await - } - FilterType::AbsoluteStart | FilterType::AbsoluteRange => { - let (start_group, start_object) = self.downstream_subscription.get_absolute_start(); - let start_group = start_group.unwrap(); - let start_object = start_object.unwrap(); - - if group_id == start_group { - object_cache_storage - .get_absolute_subgroup_stream_object( - &self.cache_key, - group_id, - subgroup_id, - start_object, - ) - .await - } else { - object_cache_storage - .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) - .await - } - } - } - } - - async fn try_get_subsequent_object( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - object_cache_id: usize, - ) -> Result> { - let stream_object_with_cache_id = match self.data_stream_type { - DataStreamType::StreamHeaderTrack => { - let track_stream_object_with_cache_id = object_cache_storage - .get_next_track_stream_object(&self.cache_key, object_cache_id) - .await?; - - if track_stream_object_with_cache_id.is_none() { - None - } else { - let (cache_id, object) = track_stream_object_with_cache_id.unwrap(); - let stream_object = StreamObject::Track(object); - - Some((cache_id, stream_object)) - } - } - DataStreamType::StreamHeaderSubgroup => { - let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); - let subgroup_stream_object_with_cache_id = object_cache_storage - .get_next_subgroup_stream_object( - &self.cache_key, - group_id, - subgroup_id, - object_cache_id, - ) - .await?; - - if subgroup_stream_object_with_cache_id.is_none() { - None - } else { - let (cache_id, object) = subgroup_stream_object_with_cache_id.unwrap(); - let stream_object = StreamObject::Subgroup(object); - - Some((cache_id, stream_object)) - } - } - _ => { - let msg = "data stream type is not StreamHeader"; - bail!(msg) - } - }; - - Ok(stream_object_with_cache_id) - } - - async fn packetize_header(&mut self, stream_header: &StreamHeader) -> Result { - let message_buf = match stream_header { - StreamHeader::Track(header) => self.packetize_track_header(header), - StreamHeader::Subgroup(header) => self.packetize_subgroup_header(header), - }; - - Ok(message_buf) - } - - fn packetize_track_header(&self, header: &track_stream::Header) -> BytesMut { - let mut buf = BytesMut::new(); - let downstream_subscribe_id = self.downstream_subscribe_id; - let downstream_track_alias = self.downstream_subscription.get_track_alias(); - - let header = track_stream::Header::new( - downstream_subscribe_id, - downstream_track_alias, - header.publisher_priority(), - ) - .unwrap(); - - header.packetize(&mut buf); - - let mut message_buf = BytesMut::with_capacity(buf.len() + 8); - message_buf.extend(write_variable_integer( - u8::from(DataStreamType::StreamHeaderTrack) as u64, - )); - message_buf.extend(buf); - - message_buf - } - - fn packetize_subgroup_header(&self, header: &subgroup_stream::Header) -> BytesMut { - let mut buf = BytesMut::new(); - let downstream_subscribe_id = self.downstream_subscribe_id; - let downstream_track_alias = self.downstream_subscription.get_track_alias(); - - let header = subgroup_stream::Header::new( - downstream_subscribe_id, - downstream_track_alias, - header.group_id(), - header.subgroup_id(), - header.publisher_priority(), - ) - .unwrap(); - - header.packetize(&mut buf); - - let mut message_buf = BytesMut::with_capacity(buf.len() + 8); - message_buf.extend(write_variable_integer( - u8::from(DataStreamType::StreamHeaderSubgroup) as u64, - )); - message_buf.extend(buf); - - message_buf - } - - async fn packetize_object(&mut self, stream_object: &StreamObject) -> Result { - let mut buf = BytesMut::new(); - - match stream_object { - StreamObject::Track(track_object) => track_object.packetize(&mut buf), - StreamObject::Subgroup(subgroup_object) => subgroup_object.packetize(&mut buf), - } - - let mut message_buf = BytesMut::with_capacity(buf.len()); - message_buf.extend(buf); - - Ok(message_buf) - } - - async fn send(&mut self, message_buf: BytesMut) -> Result<()> { - if let Err(e) = self.stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); - } - - Ok(()) - } - - fn is_subscription_ended(&self, stream_object: &StreamObject) -> bool { - let (group_id, object_id) = match stream_object { - StreamObject::Track(track_stream_object) => ( - track_stream_object.group_id(), - track_stream_object.object_id(), - ), - StreamObject::Subgroup(subgroup_stream_object) => ( - self.subgroup_stream_id.unwrap().0, - subgroup_stream_object.object_id(), - ), - }; - - self.downstream_subscription.is_end(group_id, object_id) - } - - // This function is implemented according to the following sentence in draft. - // A relay MAY treat receipt of EndOfGroup, EndOfSubgroup, GroupDoesNotExist, or - // EndOfTrack objects as a signal to close corresponding streams even if the FIN - // has not arrived, as further objects on the stream would be a protocol violation. - fn is_data_stream_ended(&self, stream_object: &StreamObject) -> bool { - match stream_object { - StreamObject::Track(track_stream_object) => { - matches!( - track_stream_object.object_status(), - Some(ObjectStatus::EndOfTrackAndGroup) - ) - } - StreamObject::Subgroup(subgroup_stream_object) => { - matches!( - subgroup_stream_object.object_status(), - Some(ObjectStatus::EndOfSubgroup) - | Some(ObjectStatus::EndOfGroup) - | Some(ObjectStatus::EndOfTrackAndGroup) - ) - } - } - } -} diff --git a/moqt-server/src/modules/server_processes/data_streams/stream/receiver.rs b/moqt-server/src/modules/server_processes/data_streams/stream/receiver.rs deleted file mode 100644 index 1d9c3788..00000000 --- a/moqt-server/src/modules/server_processes/data_streams/stream/receiver.rs +++ /dev/null @@ -1,597 +0,0 @@ -use super::uni_stream::UniRecvStream; -use crate::{ - modules::{ - buffer_manager::{request_buffer, BufferCommand}, - message_handlers::{ - stream_header::{self, StreamHeader, StreamHeaderProcessResult}, - stream_object::{self, StreamObject, StreamObjectProcessResult}, - }, - moqt_client::MOQTClient, - object_cache_storage::{ - cache::{CacheKey, SubgroupStreamId}, - wrapper::ObjectCacheStorageWrapper, - }, - pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, - server_processes::senders::Senders, - }, - TerminationError, -}; -use anyhow::Result; -use bytes::BytesMut; -use moqt_core::{ - constants::TerminationErrorCode, - data_stream_type::DataStreamType, - messages::data_streams::object_status::ObjectStatus, - models::{subscriptions::Subscription, tracks::ForwardingPreference}, - pubsub_relation_manager_repository::PubSubRelationManagerRepository, -}; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::{self}; - -pub(crate) struct StreamObjectReceiver { - stream: UniRecvStream, - buf: Arc>, - senders: Arc, - client: Arc>, - duration: u64, - subscribe_id: Option, - data_stream_type: Option, - upstream_subscription: Option, - subgroup_stream_id: Option, -} - -impl StreamObjectReceiver { - pub(crate) async fn init(stream: UniRecvStream, client: Arc>) -> Self { - let senders = client.lock().await.senders(); - let stable_id = stream.stable_id(); - let stream_id = stream.stream_id(); - let buf = request_buffer(senders.buffer_tx().clone(), stable_id, stream_id).await; - // TODO: Set the accurate duration - let duration = 100000; - - StreamObjectReceiver { - stream, - buf, - senders, - client, - duration, - subscribe_id: None, - data_stream_type: None, - upstream_subscription: None, - subgroup_stream_id: None, - } - } - - pub(crate) async fn start(&mut self) -> Result<(), TerminationError> { - let mut object_cache_storage = - ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); - - let mut is_end = false; - let session_id = self.client.lock().await.id(); - - while !is_end { - let read_bytes = self.read_stream().await?; - self.add_to_buf(read_bytes).await; - - if !self.has_header() { - self.receive_header(session_id, &mut object_cache_storage) - .await?; - - // If the header has not been received, continue to receive the header. - if !self.has_header() { - continue; - } - } - - is_end = self.receive_objects(&mut object_cache_storage).await?; - } - - Ok(()) - } - - pub(crate) async fn finish(&self) -> Result<()> { - self.senders - .buffer_tx() - .send(BufferCommand::ReleaseStream { - session_id: self.stream.stable_id(), - stream_id: self.stream.stream_id(), - }) - .await?; - - tracing::debug!("StreamObjectReceiver finished"); - - Ok(()) - } - - async fn read_stream(&mut self) -> Result { - let mut buffer = vec![0; 65536].into_boxed_slice(); - - let length: usize = match self.stream.read(&mut buffer).await { - Ok(byte_read) => byte_read.unwrap(), - Err(err) => { - let msg = format!("Failed to read from stream: {:?}", err); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - }; - - Ok(BytesMut::from(&buffer[..length])) - } - - async fn add_to_buf(&mut self, read_buf: BytesMut) { - let mut buf = self.buf.lock().await; - buf.extend_from_slice(&read_buf); - } - - fn has_header(&self) -> bool { - self.upstream_subscription.is_some() - } - - async fn receive_header( - &mut self, - session_id: usize, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(), TerminationError> { - let header = match self.read_header_from_buf().await? { - Some(header) => header, - None => { - return Ok(()); - } - }; - - self.set_subscribe_id(&header).await?; - self.set_data_stream_type(&header).await?; - - let subscribe_id = self.subscribe_id.unwrap(); - let data_stream_type = self.data_stream_type.unwrap(); - - self.set_upstream_forwarding_preference(session_id, subscribe_id, data_stream_type) - .await?; - self.set_upstream_subscription(session_id, subscribe_id) - .await?; - - if let StreamHeader::Subgroup(header) = &header { - self.subgroup_stream_id = Some((header.group_id(), header.subgroup_id())); - } - - self.create_cache_storage(session_id, subscribe_id, header, object_cache_storage) - .await?; - - self.create_forwarders(session_id, subscribe_id).await?; - - Ok(()) - } - - async fn read_header_from_buf(&self) -> Result, TerminationError> { - let result = self.try_read_header_from_buf().await; - - match result { - StreamHeaderProcessResult::Success(stream_header) => Ok(Some(stream_header)), - StreamHeaderProcessResult::Continue => Ok(None), - StreamHeaderProcessResult::Failure(code, reason) => { - let msg = std::format!("stream_header_read failure: {:?}", reason); - Err((code, msg)) - } - } - } - - async fn try_read_header_from_buf(&self) -> StreamHeaderProcessResult { - let mut process_buf = self.buf.lock().await; - let client = self.client.clone(); - - stream_header::try_read_header(&mut process_buf, client).await - } - - async fn set_subscribe_id( - &mut self, - stream_header: &StreamHeader, - ) -> Result<(), TerminationError> { - let subscribe_id = match stream_header { - StreamHeader::Track(header) => header.subscribe_id(), - StreamHeader::Subgroup(header) => header.subscribe_id(), - }; - - self.subscribe_id = Some(subscribe_id); - - Ok(()) - } - - async fn set_data_stream_type( - &mut self, - stream_header: &StreamHeader, - ) -> Result<(), TerminationError> { - let data_stream_type = match stream_header { - StreamHeader::Track(_) => DataStreamType::StreamHeaderTrack, - StreamHeader::Subgroup(_) => DataStreamType::StreamHeaderSubgroup, - }; - - self.data_stream_type = Some(data_stream_type); - - Ok(()) - } - - async fn set_upstream_forwarding_preference( - &self, - upstream_session_id: usize, - upstream_subscribe_id: u64, - data_stream_type: DataStreamType, - ) -> Result<(), TerminationError> { - let forwarding_preference = match data_stream_type { - DataStreamType::StreamHeaderTrack => ForwardingPreference::Track, - DataStreamType::StreamHeaderSubgroup => ForwardingPreference::Subgroup, - _ => { - let msg = "data_stream_type not matched".to_string(); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - }; - - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); - match pubsub_relation_manager - .set_upstream_forwarding_preference( - upstream_session_id, - upstream_subscribe_id, - forwarding_preference, - ) - .await - { - Ok(_) => Ok(()), - Err(err) => { - let msg = format!("Fail to set upstream forwarding preference: {:?}", err); - let code = TerminationErrorCode::InternalError; - - Err((code, msg)) - } - } - } - - async fn set_upstream_subscription( - &mut self, - upstream_session_id: usize, - upstream_subscribe_id: u64, - ) -> Result<(), TerminationError> { - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); - let upstream_subscription = match pubsub_relation_manager - .get_upstream_subscription_by_ids(upstream_session_id, upstream_subscribe_id) - .await - { - Ok(upstream_subscription) => upstream_subscription, - Err(err) => { - let msg = format!("Fail to get upstream subscription: {:?}", err); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - }; - - if upstream_subscription.is_none() { - let msg = "Upstream subscription not found".to_string(); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - - self.upstream_subscription = upstream_subscription; - - Ok(()) - } - - async fn create_cache_storage( - &self, - upstream_session_id: usize, - upstream_subscribe_id: u64, - stream_header: StreamHeader, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(), TerminationError> { - let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); - - let result = match stream_header { - StreamHeader::Track(track_header) => { - object_cache_storage - .create_track_stream_cache(&cache_key, track_header) - .await - } - StreamHeader::Subgroup(subgroup_header) => { - let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); - object_cache_storage - .create_subgroup_stream_cache( - &cache_key, - group_id, - subgroup_id, - subgroup_header, - ) - .await - } - }; - - match result { - Ok(_) => Ok(()), - Err(err) => { - let msg = format!("Fail to create cache storage: {:?}", err); - let code = TerminationErrorCode::InternalError; - - Err((code, msg)) - } - } - } - - async fn create_forwarders( - &self, - upstream_session_id: usize, - upstream_subscribe_id: u64, - ) -> Result<(), TerminationError> { - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); - - let subscribers = match pubsub_relation_manager - .get_related_subscribers(upstream_session_id, upstream_subscribe_id) - .await - { - Ok(subscribers) => subscribers, - Err(err) => { - let msg = format!("Fail to get related subscribers: {:?}", err); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - }; - - for (downstream_session_id, downstream_subscribe_id) in subscribers { - match self - .create_forwarder(downstream_session_id, downstream_subscribe_id) - .await - { - Ok(_) => {} - Err(err) => { - let msg = format!("Fail to create forwarder: {:?}", err); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - } - } - Ok(()) - } - - async fn create_forwarder( - &self, - downstream_session_id: usize, - downstream_subscribe_id: u64, - ) -> Result<()> { - let start_forwarder_txes = self.senders.start_forwarder_txes(); - let data_stream_type = self.data_stream_type.unwrap(); - - let start_forwarder_tx = start_forwarder_txes - .lock() - .await - .get(&downstream_session_id) - .unwrap() - .clone(); - - start_forwarder_tx - .send(( - downstream_subscribe_id, - data_stream_type, - self.subgroup_stream_id, - )) - .await?; - - Ok(()) - } - - async fn receive_objects( - &self, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result { - let session_id = self.client.lock().await.id(); - let subscribe_id = self.subscribe_id.unwrap(); - let mut is_end = false; - - while !is_end { - is_end = match self - .receive_object(session_id, subscribe_id, object_cache_storage) - .await? - { - Some(is_end) => is_end, - None => break, // Return to read stream again since there is no object in the buffer. - }; - } - - Ok(is_end) - } - - async fn receive_object( - &self, - session_id: usize, - subscribe_id: u64, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result, TerminationError> { - let stream_object = match self.read_object_from_buf().await? { - Some(object) => object, - None => { - return Ok(None); - } - }; - - self.store_object( - &stream_object, - session_id, - subscribe_id, - object_cache_storage, - ) - .await?; - - let is_end = - self.is_subscription_ended(&stream_object) || self.is_data_stream_ended(&stream_object); - - Ok(Some(is_end)) - } - - async fn read_object_from_buf(&self) -> Result, TerminationError> { - let result = self.try_read_object_from_buf().await; - - match result { - StreamObjectProcessResult::Success(stream_object) => Ok(Some(stream_object)), - StreamObjectProcessResult::Continue => Ok(None), - StreamObjectProcessResult::Failure(code, reason) => { - let msg = std::format!("stream_object_read failure: {:?}", reason); - Err((code, msg)) - } - } - } - - async fn try_read_object_from_buf(&self) -> StreamObjectProcessResult { - let mut buf = self.buf.lock().await; - let data_stream_type = self.data_stream_type.unwrap(); - - stream_object::try_read_object(&mut buf, data_stream_type).await - } - - async fn store_object( - &self, - stream_object: &StreamObject, - upstream_session_id: usize, - upstream_subscribe_id: u64, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(), TerminationError> { - let data_stream_type = self.data_stream_type.unwrap(); - let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); - - match data_stream_type { - DataStreamType::StreamHeaderTrack => { - self.store_track_stream_object(stream_object, &cache_key, object_cache_storage) - .await?; - } - DataStreamType::StreamHeaderSubgroup => { - self.store_subgroup_stream_object(stream_object, &cache_key, object_cache_storage) - .await?; - } - _ => { - let msg = "data_stream_type not matched".to_string(); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - } - - Ok(()) - } - - async fn store_track_stream_object( - &self, - stream_object: &StreamObject, - cache_key: &CacheKey, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(), TerminationError> { - let track_stream_object = match stream_object { - StreamObject::Track(object) => object, - _ => { - let msg = "StreamObject is not Track".to_string(); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - }; - match object_cache_storage - .set_track_stream_object(cache_key, track_stream_object.clone(), self.duration) - .await - { - Ok(_) => Ok(()), - Err(err) => { - let msg = format!( - "Fail to store track stream object to cache storage: {:?}", - err - ); - let code = TerminationErrorCode::InternalError; - - Err((code, msg)) - } - } - } - - async fn store_subgroup_stream_object( - &self, - stream_object: &StreamObject, - cache_key: &CacheKey, - object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(), TerminationError> { - let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); - let subgroup_stream_object = match stream_object { - StreamObject::Subgroup(object) => object, - _ => { - let msg = "StreamObject is not Subgroup".to_string(); - let code = TerminationErrorCode::InternalError; - - return Err((code, msg)); - } - }; - match object_cache_storage - .set_subgroup_stream_object( - cache_key, - group_id, - subgroup_id, - subgroup_stream_object.clone(), - self.duration, - ) - .await - { - Ok(_) => Ok(()), - Err(err) => { - let msg = format!( - "Fail to store subgroup stream object to cache storage: {:?}", - err - ); - let code = TerminationErrorCode::InternalError; - - Err((code, msg)) - } - } - } - - fn is_subscription_ended(&self, object: &StreamObject) -> bool { - let (group_id, object_id) = match object { - StreamObject::Track(track_stream_object) => ( - track_stream_object.group_id(), - track_stream_object.object_id(), - ), - StreamObject::Subgroup(subgroup_stream_object) => { - let (subgroup_group_id, _) = self.subgroup_stream_id.unwrap(); - (subgroup_group_id, subgroup_stream_object.object_id()) - } - }; - - self.upstream_subscription - .as_ref() - .unwrap() - .is_end(group_id, object_id) - } - - // This function is implemented according to the following sentence in draft. - // A relay MAY treat receipt of EndOfGroup, EndOfSubgroup, GroupDoesNotExist, or - // EndOfTrack objects as a signal to close corresponding streams even if the FIN - // has not arrived, as further objects on the stream would be a protocol violation. - // TODO: Add handling for FIN message - fn is_data_stream_ended(&self, stream_object: &StreamObject) -> bool { - match stream_object { - StreamObject::Track(track_stream_object) => { - matches!( - track_stream_object.object_status(), - Some(ObjectStatus::EndOfTrackAndGroup) - ) - } - StreamObject::Subgroup(subgroup_stream_object) => { - matches!( - subgroup_stream_object.object_status(), - Some(ObjectStatus::EndOfSubgroup) - | Some(ObjectStatus::EndOfGroup) - | Some(ObjectStatus::EndOfTrackAndGroup) - ) - } - } - } -} diff --git a/moqt-server/src/modules/server_processes/data_streams/stream.rs b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream.rs similarity index 100% rename from moqt-server/src/modules/server_processes/data_streams/stream.rs rename to moqt-server/src/modules/server_processes/data_streams/subgroup_stream.rs diff --git a/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/forwarder.rs b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/forwarder.rs new file mode 100644 index 00000000..b7848201 --- /dev/null +++ b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/forwarder.rs @@ -0,0 +1,594 @@ +use super::uni_stream::UniSendStream; +use crate::{ + modules::{ + buffer_manager::BufferCommand, + moqt_client::MOQTClient, + object_cache_storage::{cache::CacheKey, wrapper::ObjectCacheStorageWrapper}, + pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, + server_processes::senders::Senders, + }, + signal_dispatcher::{DataStreamThreadSignal, SignalDispatcher, TerminateReason}, + SubgroupStreamId, +}; +use anyhow::{bail, Ok, Result}; +use bytes::BytesMut; +use moqt_core::{ + data_stream_type::DataStreamType, + messages::{ + control_messages::subscribe::FilterType, + data_streams::{object_status::ObjectStatus, subgroup_stream, DataStreams}, + }, + models::{ + range::{ObjectRange, ObjectStart}, + tracks::ForwardingPreference, + }, + pubsub_relation_manager_repository::PubSubRelationManagerRepository, + variable_integer::write_variable_integer, +}; +use std::{ + sync::{atomic::AtomicBool, atomic::Ordering, Arc}, + thread, + time::Duration, +}; +use tokio::sync::{mpsc, Mutex}; +use tracing::{self}; + +pub(crate) struct SubgroupStreamObjectForwarder { + stream: UniSendStream, + senders: Arc, + downstream_subscribe_id: u64, + downstream_track_alias: u64, + cache_key: CacheKey, + subgroup_stream_id: SubgroupStreamId, + filter_type: FilterType, + is_terminated: Arc, + requested_object_range: ObjectRange, + sleep_time: Duration, +} + +impl SubgroupStreamObjectForwarder { + pub(crate) async fn init( + stream: UniSendStream, + downstream_subscribe_id: u64, + client: Arc>, + subgroup_stream_id: SubgroupStreamId, + mut signal_rx: mpsc::Receiver>, + ) -> Result { + let senders = client.lock().await.senders(); + let sleep_time = Duration::from_millis(10); + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); + + let downstream_session_id = stream.stable_id(); + + let downstream_track_alias = pubsub_relation_manager + .get_downstream_track_alias(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + let filter_type = pubsub_relation_manager + .get_downstream_filter_type(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + let requested_object_range = pubsub_relation_manager + .get_downstream_requested_object_range(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + // Get the information of the original publisher who has the track being requested + let (upstream_session_id, upstream_subscribe_id) = pubsub_relation_manager + .get_related_publisher(downstream_session_id, downstream_subscribe_id) + .await?; + + // Register stream_id to receive signal from other subgroup forwarder threads belong to the same group + let (group_id, subgroup_id) = subgroup_stream_id; + let stream_id = stream.stream_id(); + pubsub_relation_manager + .set_downstream_stream_id( + downstream_session_id, + downstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + ) + .await?; + + // Task to receive termination signal + let is_terminated = Arc::new(AtomicBool::new(false)); + let is_terminated_clone = is_terminated.clone(); + tokio::spawn(async move { + while let Some(signal) = signal_rx.recv().await { + match *signal { + DataStreamThreadSignal::Terminate(reason) => { + tracing::debug!("Received Terminate signal (reason: {:?})", reason); + is_terminated_clone.store(true, Ordering::Relaxed); + } + } + } + }); + + let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); + + let stream_object_forwarder = SubgroupStreamObjectForwarder { + stream, + senders, + downstream_subscribe_id, + downstream_track_alias, + cache_key, + subgroup_stream_id, + filter_type, + requested_object_range, + is_terminated, + sleep_time, + }; + + Ok(stream_object_forwarder) + } + + pub(crate) async fn start(&mut self) -> Result<()> { + let mut object_cache_storage = + ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); + + let upstream_forwarding_preference = self.get_upstream_forwarding_preference().await?; + self.validate_forwarding_preference(&upstream_forwarding_preference) + .await?; + + let downstream_forwarding_preference = upstream_forwarding_preference.clone(); + self.set_forwarding_preference(downstream_forwarding_preference) + .await?; + + self.forward_header(&mut object_cache_storage).await?; + + self.forward_objects(&mut object_cache_storage).await?; + + Ok(()) + } + + pub(crate) async fn finish(&mut self) -> Result<()> { + let downstream_session_id = self.stream.stable_id(); + let downstream_stream_id = self.stream.stream_id(); + self.senders + .buffer_tx() + .send(BufferCommand::ReleaseStream { + session_id: downstream_session_id, + stream_id: downstream_stream_id, + }) + .await?; + + // Send RESET_STREAM frame to the subscriber + self.stream.finish().await?; + + tracing::info!("SubgroupStreamObjectForwarder finished"); + + Ok(()) + } + + async fn get_upstream_forwarding_preference(&self) -> Result> { + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + + let upstream_session_id = self.cache_key.session_id(); + let upstream_subscribe_id = self.cache_key.subscribe_id(); + + pubsub_relation_manager + .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) + .await + } + + async fn validate_forwarding_preference( + &self, + upstream_forwarding_preference: &Option, + ) -> Result<()> { + match upstream_forwarding_preference { + Some(ForwardingPreference::Subgroup) => Ok(()), + _ => { + bail!("Forwarding preference is not Subgroup Stream"); + } + } + } + + async fn set_forwarding_preference( + &self, + downstream_forwarding_preference: Option, + ) -> Result<()> { + let forwarding_preference = downstream_forwarding_preference.unwrap(); + let downstream_session_id = self.stream.stable_id(); + let downstream_subscribe_id = self.downstream_subscribe_id; + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + + pubsub_relation_manager + .set_downstream_forwarding_preference( + downstream_session_id, + downstream_subscribe_id, + forwarding_preference, + ) + .await?; + + Ok(()) + } + + async fn forward_header( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<()> { + let upstream_header = self.get_upstream_header(object_cache_storage).await?; + + let downstream_header = self.generate_downstream_header(&upstream_header).await; + + let message_buf = self.packetize_header(&downstream_header).await?; + self.send(message_buf).await?; + + Ok(()) + } + + async fn get_upstream_header( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result { + let (group_id, subgroup_id) = self.subgroup_stream_id; + let subgroup_stream_header = object_cache_storage + .get_subgroup_stream_header(&self.cache_key, group_id, subgroup_id) + .await?; + + Ok(subgroup_stream_header) + } + + async fn forward_objects( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<()> { + let mut object_cache_id = None; + let mut is_end = false; + + while !is_end { + (object_cache_id, is_end) = self + .forward_object(object_cache_storage, object_cache_id) + .await?; + } + + Ok(()) + } + + async fn forward_object( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + cache_id: Option, + ) -> Result<(Option, bool)> { + // Do loop until get an object from the cache storage + loop { + let is_terminated = self.is_terminated.load(Ordering::Relaxed); + if is_terminated { + let is_end = true; + return Ok((cache_id, is_end)); + } + + let (cache_id, stream_object) = + match self.try_get_object(object_cache_storage, cache_id).await? { + Some((id, object)) => (id, object), + None => { + // If there is no object in the cache storage, sleep for a while and try again + thread::sleep(self.sleep_time); + continue; + } + }; + + let message_buf = self.packetize_object(&stream_object).await?; + self.send(message_buf).await?; + + let is_data_stream_ended = self.is_data_stream_ended(&stream_object); + + if is_data_stream_ended { + let stream_ids = self.get_stream_ids_for_same_group().await?; + + // Wait to forward rest of the objects on other forwarders in the same group + let send_delay_ms = Duration::from_millis(50); // FIXME: Temporary threshold + thread::sleep(send_delay_ms); + + for stream_id in stream_ids { + // Skip the stream of this forwarder + if stream_id == self.stream.stream_id() { + continue; + } + self.send_termination_signal_to_forwarder(&stream_object, stream_id) + .await?; + } + } + + let is_end = is_data_stream_ended; + + return Ok((Some(cache_id), is_end)); + } + } + + async fn try_get_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + cache_id: Option, + ) -> Result> { + match cache_id { + // Try to get the first object according to Filter Type + None => self.try_get_first_object(object_cache_storage).await, + Some(cache_id) => { + // Try to get the subsequent object with cache_id + self.try_get_subsequent_object(object_cache_storage, cache_id) + .await + } + } + } + + async fn try_get_first_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result> { + let downstream_session_id = self.stream.stable_id(); + let downstream_subscribe_id = self.downstream_subscribe_id; + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let actual_object_start = pubsub_relation_manager + .get_downstream_actual_object_start(downstream_session_id, downstream_subscribe_id) + .await?; + + match actual_object_start { + None => { + // If there is no actual start, it means that this is the first forwarder on this subscription. + let object_with_cache_id = self + .try_get_first_object_for_first_stream(object_cache_storage) + .await?; + + if object_with_cache_id.is_none() { + return Ok(None); + } + + let (cache_id, stream_object) = object_with_cache_id.unwrap(); + let group_id = self.subgroup_stream_id.0; + let object_id = stream_object.object_id(); + let actual_object_start = ObjectStart::new(group_id, object_id); + + pubsub_relation_manager + .set_downstream_actual_object_start( + downstream_session_id, + downstream_subscribe_id, + actual_object_start, + ) + .await?; + + Ok(Some((cache_id, stream_object))) + } + Some(actual_object_start) => { + // If there is an actual start, it means that this is the second or later forwarder on this subscription. + self.try_get_first_object_for_subsequent_stream( + object_cache_storage, + actual_object_start, + ) + .await + } + } + } + + async fn try_get_first_object_for_first_stream( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result> { + let (group_id, subgroup_id) = self.subgroup_stream_id; + + match self.filter_type { + FilterType::LatestGroup => { + // TODO: Remove LatestGroup since it is not exist in the draft-10 + object_cache_storage + .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) + .await + } + FilterType::LatestObject => { + // If the subscriber is the first subscriber for this track, the Relay needs to + // start sending from first object for the subscriber to decode the contents. + object_cache_storage + .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) + .await + } + FilterType::AbsoluteStart | FilterType::AbsoluteRange => { + let start_group = self.requested_object_range.start_group().unwrap(); + let start_object = self.requested_object_range.start_object().unwrap(); + + if group_id == start_group { + object_cache_storage + .get_absolute_subgroup_stream_object( + &self.cache_key, + group_id, + subgroup_id, + start_object, + ) + .await + } else { + object_cache_storage + .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) + .await + } + } + } + } + + async fn try_get_first_object_for_subsequent_stream( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + actual_object_start: ObjectStart, + ) -> Result> { + let (group_id, subgroup_id) = self.subgroup_stream_id; + + if group_id == actual_object_start.group_id() { + // If the actual start group id is the same as the group_id of this subgroup stream, + // this subgroup stream belongs same group with the first subgroup stream. + // So get the object with same object id with the first subgroup stream. + object_cache_storage + .get_absolute_subgroup_stream_object( + &self.cache_key, + group_id, + subgroup_id, + actual_object_start.object_id(), + ) + .await + } else { + // Else, this subgroup stream belongs to a later group than the first subgroup stream. + // So start from the first object in the subgroup stream. + object_cache_storage + .get_first_subgroup_stream_object(&self.cache_key, group_id, subgroup_id) + .await + } + } + + async fn try_get_subsequent_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + object_cache_id: usize, + ) -> Result> { + let (group_id, subgroup_id) = self.subgroup_stream_id; + object_cache_storage + .get_next_subgroup_stream_object( + &self.cache_key, + group_id, + subgroup_id, + object_cache_id, + ) + .await + } + + async fn generate_downstream_header( + &self, + upstream_header: &subgroup_stream::Header, + ) -> subgroup_stream::Header { + subgroup_stream::Header::new( + self.downstream_track_alias, // Replace with downstream_track_alias + upstream_header.group_id(), + upstream_header.subgroup_id(), + upstream_header.publisher_priority(), + ) + .unwrap() + } + + async fn packetize_header(&self, header: &subgroup_stream::Header) -> Result { + let downstream_session_id = self.stream.stable_id(); + let downstream_subscribe_id = self.downstream_subscribe_id; + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let downstream_track_alias = pubsub_relation_manager + .get_downstream_track_alias(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + let header = subgroup_stream::Header::new( + downstream_track_alias, + header.group_id(), + header.subgroup_id(), + header.publisher_priority(), + ) + .unwrap(); + + let mut buf = BytesMut::new(); + header.packetize(&mut buf); + + let mut message_buf = BytesMut::with_capacity(buf.len() + 8); + message_buf.extend(write_variable_integer( + u8::from(DataStreamType::SubgroupHeader) as u64, + )); + message_buf.extend(buf); + + Ok(message_buf) + } + + async fn packetize_object( + &mut self, + stream_object: &subgroup_stream::Object, + ) -> Result { + let mut buf = BytesMut::new(); + stream_object.packetize(&mut buf); + + let mut message_buf = BytesMut::with_capacity(buf.len()); + message_buf.extend(buf); + + Ok(message_buf) + } + + async fn send(&mut self, message_buf: BytesMut) -> Result<()> { + if let Err(e) = self.stream.write_all(&message_buf).await { + tracing::warn!("Failed to write to stream: {:?}", e); + bail!(e); + } + + Ok(()) + } + + // This function is implemented according to the following sentence in draft. + // A relay MAY treat receipt of EndOfGroup, EndOfTrack, GroupDoesNotExist, or + // EndOfTrack objects as a signal to close corresponding streams even if the FIN + // has not arrived, as further objects on the stream would be a protocol violation. + fn is_data_stream_ended(&self, stream_object: &subgroup_stream::Object) -> bool { + matches!( + stream_object.object_status(), + Some(ObjectStatus::EndOfTrack) + | Some(ObjectStatus::EndOfGroup) + | Some(ObjectStatus::EndOfTrackAndGroup) + ) + } + + async fn get_stream_ids_for_same_group(&self) -> Result> { + let downstream_session_id = self.stream.stable_id(); + let downstream_subscribe_id = self.downstream_subscribe_id; + let (group_id, _) = self.subgroup_stream_id; + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let subgroup_ids = pubsub_relation_manager + .get_downstream_subgroup_ids_for_group( + downstream_session_id, + downstream_subscribe_id, + group_id, + ) + .await?; + + let mut stream_ids: Vec = vec![]; + for subgroup_id in subgroup_ids { + let stream_id = pubsub_relation_manager + .get_downstream_stream_id_for_subgroup( + downstream_session_id, + downstream_subscribe_id, + group_id, + subgroup_id, + ) + .await? + .unwrap(); + + stream_ids.push(stream_id); + } + + Ok(stream_ids) + } + + async fn send_termination_signal_to_forwarder( + &self, + object: &subgroup_stream::Object, + stream_id: u64, + ) -> Result<()> { + let downstream_session_id = self.stream.stable_id(); + let object_status = object.object_status().unwrap(); + + let signal_dispatcher = SignalDispatcher::new(self.senders.signal_dispatch_tx().clone()); + + tracing::debug!( + "Send termination signal to downstream session: {}, stream: {}", + downstream_session_id, + stream_id + ); + + let terminate_reason = TerminateReason::ObjectStatus(object_status); + let signal = Box::new(DataStreamThreadSignal::Terminate(terminate_reason)); + signal_dispatcher + .transfer_signal_to_data_stream_thread(downstream_session_id, stream_id, signal) + .await?; + + Ok(()) + } +} diff --git a/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/receiver.rs b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/receiver.rs new file mode 100644 index 00000000..b98049c4 --- /dev/null +++ b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/receiver.rs @@ -0,0 +1,660 @@ +use super::uni_stream::UniRecvStream; +use crate::{ + modules::{ + buffer_manager::{request_buffer, BufferCommand}, + message_handlers::{ + subgroup_stream_header::{self, SubgroupStreamHeaderProcessResult}, + subgroup_stream_object::{self, SubgroupStreamObjectProcessResult}, + }, + moqt_client::MOQTClient, + object_cache_storage::{ + cache::{CacheKey, SubgroupStreamId}, + wrapper::ObjectCacheStorageWrapper, + }, + pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, + server_processes::senders::Senders, + }, + signal_dispatcher::{DataStreamThreadSignal, SignalDispatcher, TerminateReason}, + TerminationError, +}; +use anyhow::Result; +use bytes::BytesMut; +use moqt_core::{ + constants::TerminationErrorCode, + data_stream_type::DataStreamType, + messages::{ + control_messages::subscribe::FilterType, + data_streams::{object_status::ObjectStatus, subgroup_stream}, + }, + models::{range::ObjectRange, tracks::ForwardingPreference}, + pubsub_relation_manager_repository::PubSubRelationManagerRepository, +}; +use std::{sync::Arc, thread, time::Duration}; +use tokio::sync::{mpsc, Mutex}; +use tracing::{self}; + +pub(crate) struct SubgroupStreamObjectReceiver { + stream: UniRecvStream, + buf: Arc>, + senders: Arc, + client: Arc>, + duration: u64, + signal_rx: Arc>>>, + subscribe_id: Option, + subgroup_stream_id: Option, + filter_type: Option, + requested_object_range: Option, +} + +impl SubgroupStreamObjectReceiver { + pub(crate) async fn init( + stream: UniRecvStream, + client: Arc>, + signal_rx: mpsc::Receiver>, + ) -> Self { + let senders = client.lock().await.senders(); + let stable_id = stream.stable_id(); + let stream_id = stream.stream_id(); + let buf = request_buffer(senders.buffer_tx().clone(), stable_id, stream_id).await; + // TODO: Set the accurate duration + let duration = 100000; + let signal_rx = Arc::new(Mutex::new(signal_rx)); + + SubgroupStreamObjectReceiver { + stream, + buf, + senders, + client, + duration, + signal_rx, + subscribe_id: None, + subgroup_stream_id: None, + filter_type: None, + requested_object_range: None, + } + } + + pub(crate) async fn start(&mut self) -> Result<(), TerminationError> { + let mut object_cache_storage = + ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); + + let mut is_end = false; + let session_id = self.client.lock().await.id(); + + while !is_end { + let signal_rx = self.signal_rx.clone(); + let mut signal_rx = signal_rx.lock().await; + + tokio::select! { + read_bytes = self.read_stream() => { + + let read_bytes = read_bytes?; + + self.add_to_buf(read_bytes).await; + + if !self.has_received_header() { + self.receive_header(session_id, &mut object_cache_storage) + .await?; + + // If the header has not been received, continue to receive the header. + if !self.has_received_header() { + continue; + } + } + + is_end = self.receive_objects(&mut object_cache_storage).await?; + }, + Some(signal) = signal_rx.recv() => { + match *signal { + DataStreamThreadSignal::Terminate(reason) => { + tracing::debug!("Received Terminate signal (reason: {:?})", reason); + break; + } + } + } + } + } + + Ok(()) + } + + pub(crate) async fn finish(self) -> Result<()> { + self.senders + .buffer_tx() + .send(BufferCommand::ReleaseStream { + session_id: self.stream.stable_id(), + stream_id: self.stream.stream_id(), + }) + .await?; + + // Send STOP_SENDING frame to the publisher + self.stream.stop(); + + tracing::debug!("SubgroupStreamObjectReceiver finished"); + + Ok(()) + } + + async fn read_stream(&mut self) -> Result { + // Align with the stream_receive_window configured on the MoQT Server + let mut buffer = vec![0; 10 * 1024 * 1024].into_boxed_slice(); + + let length: usize = match self.stream.read(&mut buffer).await { + Ok(byte_read) => byte_read.unwrap(), + Err(err) => { + let msg = format!("Failed to read from stream: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + + Ok(BytesMut::from(&buffer[..length])) + } + + async fn add_to_buf(&mut self, read_buf: BytesMut) { + let mut buf = self.buf.lock().await; + buf.extend_from_slice(&read_buf); + } + + fn has_received_header(&self) -> bool { + self.subscribe_id.is_some() + } + + async fn receive_header( + &mut self, + session_id: usize, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<(), TerminationError> { + let header = match self.read_header_from_buf().await? { + Some(header) => header, + None => { + return Ok(()); + } + }; + + let subscribe_id = self + .get_subscribe_id(session_id, header.track_alias()) + .await?; + self.subscribe_id = Some(subscribe_id); + self.subgroup_stream_id = Some((header.group_id(), header.subgroup_id())); + + self.set_upstream_stream_id(session_id).await?; + self.set_upstream_forwarding_preference(session_id).await?; + + let filter_type = self.get_upstream_filter_type(session_id).await?; + self.filter_type = Some(filter_type); + let requested_object_range = self.get_upstream_requested_object_range(session_id).await?; + self.requested_object_range = Some(requested_object_range); + + self.create_cache_storage(session_id, header, object_cache_storage) + .await?; + + self.create_forwarders(session_id).await?; + + Ok(()) + } + + async fn read_header_from_buf( + &self, + ) -> Result, TerminationError> { + let result = self.try_read_header_from_buf().await; + + match result { + SubgroupStreamHeaderProcessResult::Success(stream_header) => Ok(Some(stream_header)), + SubgroupStreamHeaderProcessResult::Continue => Ok(None), + SubgroupStreamHeaderProcessResult::Failure(code, reason) => { + let msg = std::format!("stream_header_read failure: {:?}", reason); + Err((code, msg)) + } + } + } + + async fn try_read_header_from_buf(&self) -> SubgroupStreamHeaderProcessResult { + let mut process_buf = self.buf.lock().await; + let client = self.client.clone(); + + subgroup_stream_header::try_read_header(&mut process_buf, client).await + } + + async fn get_subscribe_id( + &self, + session_id: usize, + track_alias: u64, + ) -> Result { + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + match pubsub_relation_manager + .get_upstream_subscribe_id_by_track_alias(session_id, track_alias) + .await + { + Ok(Some(subscribe_id)) => Ok(subscribe_id), + Ok(None) => { + let msg = "Subscribe id is not found".to_string(); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + Err(err) => { + let msg = format!("Fail to get subscribe id: {:?}", err); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } + + async fn set_upstream_stream_id( + &self, + upstream_session_id: usize, + ) -> Result<(), TerminationError> { + // Register stream_id to send signal to other subgroup receiver threads in the same group + let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); + let upstream_subscribe_id = self.subscribe_id.unwrap(); + let stream_id = self.stream.stream_id(); + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + match pubsub_relation_manager + .set_upstream_stream_id( + upstream_session_id, + upstream_subscribe_id, + group_id, + subgroup_id, + stream_id, + ) + .await + { + Ok(_) => Ok(()), + Err(err) => { + let msg = format!("Fail to set upstream stream id: {:?}", err); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } + + async fn set_upstream_forwarding_preference( + &self, + upstream_session_id: usize, + ) -> Result<(), TerminationError> { + let forwarding_preference = ForwardingPreference::Subgroup; + let upstream_subscribe_id = self.subscribe_id.unwrap(); + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + match pubsub_relation_manager + .set_upstream_forwarding_preference( + upstream_session_id, + upstream_subscribe_id, + forwarding_preference, + ) + .await + { + Ok(_) => Ok(()), + Err(err) => { + let msg = format!("Fail to set upstream forwarding preference: {:?}", err); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } + + async fn get_upstream_filter_type( + &self, + upstream_session_id: usize, + ) -> Result { + let upstream_subscribe_id = self.subscribe_id.unwrap(); + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + match pubsub_relation_manager + .get_upstream_filter_type(upstream_session_id, upstream_subscribe_id) + .await + { + Ok(Some(filter_type)) => Ok(filter_type), + Ok(None) => { + let msg = "Filter type is not found".to_string(); + let code = TerminationErrorCode::InternalError; + Err((code, msg)) + } + Err(err) => { + let msg = format!("Fail to get upstream filter type: {:?}", err); + let code = TerminationErrorCode::InternalError; + Err((code, msg)) + } + } + } + + async fn get_upstream_requested_object_range( + &mut self, + upstream_session_id: usize, + ) -> Result { + let upstream_subscribe_id = self.subscribe_id.unwrap(); + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + match pubsub_relation_manager + .get_upstream_requested_object_range(upstream_session_id, upstream_subscribe_id) + .await + { + Ok(Some(range)) => Ok(range), + Ok(None) => { + let msg = "Requested range is not found".to_string(); + let code = TerminationErrorCode::InternalError; + Err((code, msg)) + } + Err(err) => { + let msg = format!("Fail to get upstream requested range: {:?}", err); + let code = TerminationErrorCode::InternalError; + Err((code, msg)) + } + } + } + + async fn create_cache_storage( + &self, + upstream_session_id: usize, + + stream_header: subgroup_stream::Header, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<(), TerminationError> { + let upstream_subscribe_id = self.subscribe_id.unwrap(); + let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); + + let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); + let result = object_cache_storage + .create_subgroup_stream_cache(&cache_key, group_id, subgroup_id, stream_header) + .await; + match result { + Ok(_) => Ok(()), + Err(err) => { + let msg = format!("Fail to create cache storage: {:?}", err); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } + + async fn create_forwarders(&self, upstream_session_id: usize) -> Result<(), TerminationError> { + let upstream_subscribe_id = self.subscribe_id.unwrap(); + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let subscribers = match pubsub_relation_manager + .get_related_subscribers(upstream_session_id, upstream_subscribe_id) + .await + { + Ok(subscribers) => subscribers, + Err(err) => { + let msg = format!("Fail to get related subscribers: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + + for (downstream_session_id, downstream_subscribe_id) in subscribers { + match self + .create_forwarder(downstream_session_id, downstream_subscribe_id) + .await + { + Ok(_) => {} + Err(err) => { + let msg = format!("Fail to create forwarder: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + } + } + Ok(()) + } + + async fn create_forwarder( + &self, + downstream_session_id: usize, + downstream_subscribe_id: u64, + ) -> Result<()> { + let start_forwarder_txes = self.senders.start_forwarder_txes(); + let data_stream_type = DataStreamType::SubgroupHeader; + + let start_forwarder_tx = start_forwarder_txes + .lock() + .await + .get(&downstream_session_id) + .unwrap() + .clone(); + + start_forwarder_tx + .send(( + downstream_subscribe_id, + data_stream_type, + self.subgroup_stream_id, + )) + .await?; + + Ok(()) + } + + async fn receive_objects( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result { + let session_id = self.client.lock().await.id(); + let subscribe_id = self.subscribe_id.unwrap(); + let mut is_end = false; + + while !is_end { + is_end = match self + .receive_object(session_id, subscribe_id, object_cache_storage) + .await? + { + Some(is_end) => is_end, + None => break, // Return to read stream again since there is no object in the buffer. + }; + } + + Ok(is_end) + } + + async fn receive_object( + &self, + session_id: usize, + subscribe_id: u64, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result, TerminationError> { + let stream_object = match self.read_object_from_buf().await? { + Some(object) => object, + None => { + return Ok(None); + } + }; + + self.store_object( + &stream_object, + session_id, + subscribe_id, + object_cache_storage, + ) + .await?; + + let is_data_stream_ended = self.is_data_stream_ended(&stream_object); + + if is_data_stream_ended { + let stream_ids = self.get_stream_ids_for_same_group().await?; + + // Wait to forward rest of the objects on other receivers in the same group + let send_delay_ms = Duration::from_millis(50); // FIXME: Temporary threshold + thread::sleep(send_delay_ms); + + for stream_id in stream_ids { + // Skip the stream of this receiver + if stream_id == self.stream.stream_id() { + continue; + } + self.send_termination_signal_to_receiver(&stream_object, stream_id) + .await?; + } + } + + let is_end = is_data_stream_ended; + + Ok(Some(is_end)) + } + + async fn read_object_from_buf( + &self, + ) -> Result, TerminationError> { + let result = self.try_read_object_from_buf().await; + + match result { + SubgroupStreamObjectProcessResult::Success(stream_object) => Ok(Some(stream_object)), + SubgroupStreamObjectProcessResult::Continue => Ok(None), + } + } + + async fn try_read_object_from_buf(&self) -> SubgroupStreamObjectProcessResult { + let mut buf = self.buf.lock().await; + + subgroup_stream_object::try_read_object(&mut buf).await + } + + async fn store_object( + &self, + stream_object: &subgroup_stream::Object, + upstream_session_id: usize, + upstream_subscribe_id: u64, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<(), TerminationError> { + let cache_key = CacheKey::new(upstream_session_id, upstream_subscribe_id); + let (group_id, subgroup_id) = self.subgroup_stream_id.unwrap(); + + match object_cache_storage + .set_subgroup_stream_object( + &cache_key, + group_id, + subgroup_id, + stream_object.clone(), + self.duration, + ) + .await + { + Ok(_) => Ok(()), + Err(err) => { + let msg = format!( + "Fail to store subgroup stream object to cache storage: {:?}", + err + ); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } + + // This function is implemented according to the following sentence in draft. + // A relay MAY treat receipt of EndOfGroup, EndOfTrack, GroupDoesNotExist, or + // EndOfTrack objects as a signal to close corresponding streams even if the FIN + // has not arrived, as further objects on the stream would be a protocol violation. + // TODO: Add handling for FIN message + fn is_data_stream_ended(&self, stream_object: &subgroup_stream::Object) -> bool { + matches!( + stream_object.object_status(), + Some(ObjectStatus::EndOfTrack) + | Some(ObjectStatus::EndOfGroup) + | Some(ObjectStatus::EndOfTrackAndGroup) + ) + } + + async fn get_stream_ids_for_same_group(&self) -> Result, TerminationError> { + let upstream_session_id = self.stream.stable_id(); + let upstream_subscribe_id = self.subscribe_id.unwrap(); + let (group_id, _) = self.subgroup_stream_id.unwrap(); + + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let subgroup_ids = match pubsub_relation_manager + .get_upstream_subgroup_ids_for_group( + upstream_session_id, + upstream_subscribe_id, + group_id, + ) + .await + { + Ok(subgroup_ids) => subgroup_ids, + Err(err) => { + let msg = format!("Fail to get upstream subgroup ids by group: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + + let mut stream_ids: Vec = vec![]; + for subgroup_id in subgroup_ids { + let stream_id = match pubsub_relation_manager + .get_upstream_stream_id_for_subgroup( + upstream_session_id, + upstream_subscribe_id, + group_id, + subgroup_id, + ) + .await + { + Ok(Some(stream_id)) => stream_id, + Ok(None) => { + let msg = "Stream id is not found".to_string(); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + Err(err) => { + let msg = format!("Fail to get upstream stream id by subgroup: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + + stream_ids.push(stream_id); + } + + Ok(stream_ids) + } + + async fn send_termination_signal_to_receiver( + &self, + object: &subgroup_stream::Object, + stream_id: u64, + ) -> Result<(), TerminationError> { + let upstream_session_id = self.stream.stable_id(); + let object_status = object.object_status().unwrap(); + + let signal_dispatcher = SignalDispatcher::new(self.senders.signal_dispatch_tx().clone()); + + tracing::debug!( + "Send termination signal to upstream session: {}, stream: {}", + upstream_session_id, + stream_id + ); + + let terminate_reason = TerminateReason::ObjectStatus(object_status); + let signal = Box::new(DataStreamThreadSignal::Terminate(terminate_reason)); + match signal_dispatcher + .transfer_signal_to_data_stream_thread(upstream_session_id, stream_id, signal) + .await + { + Ok(_) => Ok(()), + Err(err) => { + let msg = format!("Fail to send termination signal: {:?}", err); + let code = TerminationErrorCode::InternalError; + + Err((code, msg)) + } + } + } +} diff --git a/moqt-server/src/modules/server_processes/data_streams/stream/uni_stream.rs b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/uni_stream.rs similarity index 79% rename from moqt-server/src/modules/server_processes/data_streams/stream/uni_stream.rs rename to moqt-server/src/modules/server_processes/data_streams/subgroup_stream/uni_stream.rs index dcba4d79..9184a989 100644 --- a/moqt-server/src/modules/server_processes/data_streams/stream/uni_stream.rs +++ b/moqt-server/src/modules/server_processes/data_streams/subgroup_stream/uni_stream.rs @@ -1,6 +1,6 @@ use wtransport::{ error::{StreamReadError, StreamWriteError}, - RecvStream, SendStream, + RecvStream, SendStream, VarInt, }; pub(crate) struct UniRecvStream { @@ -32,6 +32,13 @@ impl UniRecvStream { ) -> Result, StreamReadError> { self.recv_stream.read(buffer).await } + + pub(crate) fn stop(self) { + // Use code 0 for normal termination + // TODO: Use accurate error code + let code = VarInt::from_u32(0); + self.recv_stream.stop(code); + } } pub(crate) struct UniSendStream { @@ -60,4 +67,8 @@ impl UniSendStream { pub(crate) async fn write_all(&mut self, buffer: &[u8]) -> Result<(), StreamWriteError> { self.send_stream.write_all(buffer).await } + + pub(crate) async fn finish(&mut self) -> Result<(), StreamWriteError> { + self.send_stream.finish().await + } } diff --git a/moqt-server/src/modules/server_processes/senders.rs b/moqt-server/src/modules/server_processes/senders.rs index c0d3a99e..7d9881bb 100644 --- a/moqt-server/src/modules/server_processes/senders.rs +++ b/moqt-server/src/modules/server_processes/senders.rs @@ -1,8 +1,9 @@ use crate::{ modules::{ - buffer_manager::BufferCommand, object_cache_storage::commands::ObjectCacheStorageCommand, + buffer_manager::BufferCommand, control_message_dispatcher::ControlMessageDispatchCommand, + object_cache_storage::commands::ObjectCacheStorageCommand, pubsub_relation_manager::commands::PubSubRelationCommand, - send_stream_dispatcher::SendStreamDispatchCommand, + signal_dispatcher::SignalDispatchCommand, }, SenderToOpenSubscription, }; @@ -39,7 +40,8 @@ impl SenderToOtherConnectionThread { pub(crate) struct SendersToManagementThread { buffer_tx: mpsc::Sender, pubsub_relation_tx: mpsc::Sender, - send_stream_tx: mpsc::Sender, + control_message_dispatch_tx: mpsc::Sender, + signal_dispatch_tx: mpsc::Sender, object_cache_tx: mpsc::Sender, } @@ -47,13 +49,15 @@ impl SendersToManagementThread { pub(crate) fn new( buffer_tx: mpsc::Sender, pubsub_relation_tx: mpsc::Sender, - send_stream_tx: mpsc::Sender, + control_message_dispatch_tx: mpsc::Sender, + signal_dispatch_tx: mpsc::Sender, object_cache_tx: mpsc::Sender, ) -> Self { SendersToManagementThread { buffer_tx, pubsub_relation_tx, - send_stream_tx, + control_message_dispatch_tx, + signal_dispatch_tx, object_cache_tx, } } @@ -97,8 +101,16 @@ impl Senders { &self.senders_to_management_thread.pubsub_relation_tx } - pub(crate) fn send_stream_tx(&self) -> &mpsc::Sender { - &self.senders_to_management_thread.send_stream_tx + pub(crate) fn control_message_dispatch_tx( + &self, + ) -> &mpsc::Sender { + &self + .senders_to_management_thread + .control_message_dispatch_tx + } + + pub(crate) fn signal_dispatch_tx(&self) -> &mpsc::Sender { + &self.senders_to_management_thread.signal_dispatch_tx } pub(crate) fn object_cache_tx(&self) -> &mpsc::Sender { @@ -122,12 +134,14 @@ pub(crate) mod test_helper_fn { let (buffer_tx, _) = tokio::sync::mpsc::channel(1); let (pubsub_relation_tx, _) = tokio::sync::mpsc::channel(1); - let (send_stream_tx, _) = tokio::sync::mpsc::channel(1); + let (control_message_dispatch_tx, _) = tokio::sync::mpsc::channel(1); + let (signal_dispatch_tx, _) = tokio::sync::mpsc::channel(1); let (object_cache_tx, _) = tokio::sync::mpsc::channel(1); let senders_to_management_thread = super::SendersToManagementThread::new( buffer_tx, pubsub_relation_tx, - send_stream_tx, + control_message_dispatch_tx, + signal_dispatch_tx, object_cache_tx, ); diff --git a/moqt-server/src/modules/server_processes/session_handler.rs b/moqt-server/src/modules/server_processes/session_handler.rs index a788b08d..c51759e7 100644 --- a/moqt-server/src/modules/server_processes/session_handler.rs +++ b/moqt-server/src/modules/server_processes/session_handler.rs @@ -2,16 +2,17 @@ use super::senders::{SenderToOtherConnectionThread, SendersToManagementThread}; use crate::{ modules::{ buffer_manager::BufferCommand, + control_message_dispatcher::ControlMessageDispatchCommand, moqt_client::MOQTClient, object_cache_storage::wrapper::ObjectCacheStorageWrapper, pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, - send_stream_dispatcher::SendStreamDispatchCommand, server_processes::{ senders::{SenderToSelf, Senders}, thread_starters::select_spawn_thread, }, }, - SubgroupStreamId, + signal_dispatcher::{DataStreamThreadSignal, SignalDispatcher, TerminateReason}, + SignalDispatchCommand, SubgroupStreamId, }; use anyhow::Result; use moqt_core::{ @@ -110,9 +111,24 @@ impl SessionHandler { let senders = self.client.lock().await.senders(); let stable_id = self.client.lock().await.id(); - // Delete pub/sub information related to the client let pubsub_relation_manager = PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); + + let stream_ids = self + .get_all_stream_ids(&pubsub_relation_manager, stable_id) + .await?; + + self.send_terminate_signal_to_data_stream_threads(&senders, stable_id, stream_ids) + .await?; + + senders + .signal_dispatch_tx() + .send(SignalDispatchCommand::Delete { + session_id: stable_id, + }) + .await?; + + // Delete pub/sub information related to the client let _ = pubsub_relation_manager.delete_client(stable_id).await; // Delete object cache related to the client @@ -124,8 +140,8 @@ impl SessionHandler { // Delete senders to the client senders - .send_stream_tx() - .send(SendStreamDispatchCommand::Delete { + .control_message_dispatch_tx() + .send(ControlMessageDispatchCommand::Delete { session_id: stable_id, }) .await?; @@ -138,8 +154,218 @@ impl SessionHandler { }) .await?; + // Delete senders for data stream threads + senders + .signal_dispatch_tx() + .send(SignalDispatchCommand::Delete { + session_id: stable_id, + }) + .await?; + tracing::info!("SessionHandler finished"); Ok(()) } + + async fn get_all_stream_ids( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + ) -> Result> { + let upstream_stream_ids = self + .get_upstream_stream_ids(pubsub_relation_manager, stable_id) + .await?; + let downstream_stream_ids = self + .get_downstream_stream_ids(pubsub_relation_manager, stable_id) + .await?; + + let mut stream_ids = Vec::new(); + stream_ids.extend(upstream_stream_ids); + stream_ids.extend(downstream_stream_ids); + + Ok(stream_ids) + } + + async fn get_upstream_stream_ids( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + ) -> Result> { + let upstream_subscribe_ids = pubsub_relation_manager + .get_upstream_subscribe_ids_for_client(stable_id) + .await?; + + let mut stream_ids = Vec::new(); + + for subscribe_id in upstream_subscribe_ids { + let stream_ids_for_subscription = self + .get_upstream_stream_ids_for_subscription( + pubsub_relation_manager, + stable_id, + subscribe_id, + ) + .await?; + stream_ids.extend(stream_ids_for_subscription); + } + + Ok(stream_ids) + } + + async fn get_upstream_stream_ids_for_subscription( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + subscribe_id: u64, + ) -> Result> { + let group_ids = pubsub_relation_manager + .get_upstream_group_ids_for_subscription(stable_id, subscribe_id) + .await?; + + let mut stream_ids = Vec::new(); + + for group_id in group_ids { + let stream_ids_for_group = self + .get_upstream_stream_ids_for_group( + pubsub_relation_manager, + stable_id, + subscribe_id, + group_id, + ) + .await?; + + stream_ids.extend(stream_ids_for_group); + } + + Ok(stream_ids) + } + + async fn get_upstream_stream_ids_for_group( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + subscribe_id: u64, + group_id: u64, + ) -> Result> { + let subgroup_ids = pubsub_relation_manager + .get_upstream_subgroup_ids_for_group(stable_id, subscribe_id, group_id) + .await?; + + let mut stream_ids = Vec::new(); + + for subgroup_id in subgroup_ids { + let stream_id = pubsub_relation_manager + .get_upstream_stream_id_for_subgroup(stable_id, subscribe_id, group_id, subgroup_id) + .await?; + + if let Some(stream_id) = stream_id { + stream_ids.push(stream_id); + } + } + + Ok(stream_ids) + } + + async fn get_downstream_stream_ids( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + ) -> Result> { + let downstream_subscribe_ids = pubsub_relation_manager + .get_downstream_subscribe_ids_for_client(stable_id) + .await?; + + let mut stream_ids = Vec::new(); + + for subscribe_id in downstream_subscribe_ids { + let stream_ids_for_subscription = self + .get_downstream_stream_ids_for_subscription( + pubsub_relation_manager, + stable_id, + subscribe_id, + ) + .await?; + stream_ids.extend(stream_ids_for_subscription); + } + + Ok(stream_ids) + } + + async fn get_downstream_stream_ids_for_subscription( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + subscribe_id: u64, + ) -> Result> { + let group_ids = pubsub_relation_manager + .get_downstream_group_ids_for_subscription(stable_id, subscribe_id) + .await?; + + let mut stream_ids = Vec::new(); + + for group_id in group_ids { + let stream_ids_for_group = self + .get_downstream_stream_ids_for_group( + pubsub_relation_manager, + stable_id, + subscribe_id, + group_id, + ) + .await?; + + stream_ids.extend(stream_ids_for_group); + } + + Ok(stream_ids) + } + + async fn get_downstream_stream_ids_for_group( + &self, + pubsub_relation_manager: &PubSubRelationManagerWrapper, + stable_id: usize, + subscribe_id: u64, + group_id: u64, + ) -> Result> { + let subgroup_ids = pubsub_relation_manager + .get_downstream_subgroup_ids_for_group(stable_id, subscribe_id, group_id) + .await?; + + let mut stream_ids = Vec::new(); + + for subgroup_id in subgroup_ids { + let stream_id = pubsub_relation_manager + .get_downstream_stream_id_for_subgroup( + stable_id, + subscribe_id, + group_id, + subgroup_id, + ) + .await?; + + if let Some(stream_id) = stream_id { + stream_ids.push(stream_id); + } + } + + Ok(stream_ids) + } + + async fn send_terminate_signal_to_data_stream_threads( + &self, + senders: &Senders, + stable_id: usize, + stream_ids: Vec, + ) -> Result<()> { + let signal_dispatcher = SignalDispatcher::new(senders.signal_dispatch_tx().clone()); + + let terminate_reason = TerminateReason::SessionClosed; + let signal = Box::new(DataStreamThreadSignal::Terminate(terminate_reason)); + + for stream_id in stream_ids { + signal_dispatcher + .transfer_signal_to_data_stream_thread(stable_id, stream_id, signal.clone()) + .await?; + } + + Ok(()) + } } diff --git a/moqt-server/src/modules/server_processes/thread_starters.rs b/moqt-server/src/modules/server_processes/thread_starters.rs index b1f28a44..294e377d 100644 --- a/moqt-server/src/modules/server_processes/thread_starters.rs +++ b/moqt-server/src/modules/server_processes/thread_starters.rs @@ -4,21 +4,21 @@ use super::{ }, data_streams::{ datagram::{forwarder::DatagramObjectForwarder, receiver::DatagramObjectReceiver}, - stream::{ - forwarder::StreamObjectForwarder, - receiver::StreamObjectReceiver, + subgroup_stream::{ + forwarder::SubgroupStreamObjectForwarder, + receiver::SubgroupStreamObjectReceiver, uni_stream::{UniRecvStream, UniSendStream}, }, }, }; use crate::{ - modules::{moqt_client::MOQTClient, send_stream_dispatcher::SendStreamDispatchCommand}, - SubgroupStreamId, + modules::{control_message_dispatcher::ControlMessageDispatchCommand, moqt_client::MOQTClient}, + signal_dispatcher::DataStreamThreadSignal, + SignalDispatchCommand, SubgroupStreamId, }; use anyhow::{bail, Result}; use moqt_core::{ - constants::{StreamDirection, TerminationErrorCode}, - data_stream_type::DataStreamType, + constants::TerminationErrorCode, data_stream_type::DataStreamType, messages::moqt_payload::MOQTPayload, }; use std::sync::Arc; @@ -55,10 +55,9 @@ async fn spawn_control_stream_threads( let (message_tx, message_rx) = mpsc::channel::>>(1024); senders - .send_stream_tx() - .send(SendStreamDispatchCommand::Set { + .control_message_dispatch_tx() + .send(ControlMessageDispatchCommand::Set { session_id: stable_id, - stream_direction: StreamDirection::Bi, sender: message_tx, }) .await?; @@ -95,7 +94,7 @@ async fn spawn_control_stream_threads( Ok(()) } -async fn spawn_stream_object_receiver_thread( +async fn spawn_subgroup_stream_object_receiver_thread( client: Arc>, recv_stream: RecvStream, ) -> Result<()> { @@ -105,14 +104,27 @@ async fn spawn_stream_object_receiver_thread( tracing::info!("Accepted uni-directional recv stream"); }); let stream_id = recv_stream.id().into_u64(); + let (signal_tx, signal_rx) = mpsc::channel::>(1024); + + let senders = client.lock().await.senders(); + senders + .signal_dispatch_tx() + .send(SignalDispatchCommand::Set { + session_id: stable_id, + stream_id, + sender: signal_tx, + }) + .await + .unwrap(); tokio::spawn( async move { let stream = UniRecvStream::new(stable_id, stream_id, recv_stream); let senders = client.lock().await.senders(); - let mut stream_object_receiver = StreamObjectReceiver::init(stream, client) - .instrument(session_span.clone()) - .await; + let mut stream_object_receiver = + SubgroupStreamObjectReceiver::init(stream, client, signal_rx) + .instrument(session_span.clone()) + .await; match stream_object_receiver .start() @@ -140,34 +152,42 @@ async fn spawn_stream_object_receiver_thread( Ok(()) } -async fn spawn_stream_object_forwarder_thread( +async fn spawn_subgroup_stream_object_forwarder_thread( client: Arc>, send_stream: SendStream, subscribe_id: u64, - data_stream_type: DataStreamType, - subgroup_stream_id: Option, + subgroup_stream_id: SubgroupStreamId, ) -> Result<()> { let stable_id = client.lock().await.id(); let session_span = tracing::info_span!("Session", stable_id); session_span.in_scope(|| { - tracing::info!( - "Open uni-directional send for stream type: {:?}", - data_stream_type - ); + tracing::info!("Open uni-directional send for subgroup stream",); }); let stream_id = send_stream.id().into_u64(); + let (signal_tx, signal_rx) = mpsc::channel::>(1024); + + let senders = client.lock().await.senders(); + senders + .signal_dispatch_tx() + .send(SignalDispatchCommand::Set { + session_id: stable_id, + stream_id, + sender: signal_tx, + }) + .await + .unwrap(); tokio::spawn( async move { let stream = UniSendStream::new(stable_id, stream_id, send_stream); let senders = client.lock().await.senders(); - let mut stream_object_forwarder = StreamObjectForwarder::init( + let mut stream_object_forwarder = SubgroupStreamObjectForwarder::init( stream, subscribe_id, client, - data_stream_type, subgroup_stream_id, + signal_rx, ) .instrument(session_span.clone()) .await @@ -306,7 +326,7 @@ pub(crate) async fn select_spawn_thread( }, stream = session.accept_uni() => { let recv_stream = stream?; - spawn_stream_object_receiver_thread(client.clone(), recv_stream).await?; + spawn_subgroup_stream_object_receiver_thread(client.clone(), recv_stream).await?; }, datagram = session.receive_datagram() => { let datagram = datagram?; @@ -315,15 +335,19 @@ pub(crate) async fn select_spawn_thread( // Waiting for requests to open a new data stream thread Some((subscribe_id, data_stream_type, subgroup_stream_id)) = start_forwarder_rx.recv() => { match data_stream_type { - DataStreamType::StreamHeaderTrack | DataStreamType::StreamHeaderSubgroup => { + DataStreamType::SubgroupHeader => { let send_stream = session.open_uni().await?.await?; - spawn_stream_object_forwarder_thread(client.clone(), send_stream, subscribe_id, data_stream_type, subgroup_stream_id).await?; + let subgroup_stream_id = subgroup_stream_id.unwrap(); + spawn_subgroup_stream_object_forwarder_thread(client.clone(), send_stream, subscribe_id, subgroup_stream_id).await?; } - DataStreamType::ObjectDatagram => { + DataStreamType::ObjectDatagram | DataStreamType::ObjectDatagramStatus => { let session = session.clone(); spawn_datagram_object_forwarder_thread(client.clone(), session, subscribe_id).await?; } + DataStreamType::FetchHeader => { + unimplemented!(); + } } }, // TODO: Not implemented yet diff --git a/moqt-server/src/modules/signal_dispatcher.rs b/moqt-server/src/modules/signal_dispatcher.rs new file mode 100644 index 00000000..8ff1a469 --- /dev/null +++ b/moqt-server/src/modules/signal_dispatcher.rs @@ -0,0 +1,113 @@ +use anyhow::Result; +use moqt_core::messages::data_streams::object_status::ObjectStatus; +use std::collections::HashMap; +use tokio::sync::{mpsc, oneshot}; +type SenderToDataStreamThread = mpsc::Sender>; + +#[derive(Debug, Clone)] +pub(crate) enum DataStreamThreadSignal { + Terminate(TerminateReason), +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub(crate) enum TerminateReason { + ObjectStatus(ObjectStatus), + SessionClosed, +} + +#[derive(Debug)] +pub(crate) enum SignalDispatchCommand { + Set { + session_id: usize, + stream_id: u64, + sender: SenderToDataStreamThread, + }, + Get { + session_id: usize, + stream_id: u64, + resp: oneshot::Sender>, + }, + Delete { + session_id: usize, + }, +} + +pub(crate) async fn signal_dispatcher(rx: &mut mpsc::Receiver) { + tracing::trace!("signal_dispatcher start"); + // { + // "${session_id}" : { + // "${stream_id}" + // tx + // } + // } + let mut dispatcher = HashMap::>::new(); + + while let Some(cmd) = rx.recv().await { + tracing::debug!("command received: {:#?}", cmd); + match cmd { + SignalDispatchCommand::Set { + session_id, + stream_id, + sender, + } => { + let inner_map = dispatcher.entry(session_id).or_default(); + inner_map.insert(stream_id, sender); + tracing::debug!("set: {:?}", session_id); + } + SignalDispatchCommand::Get { + session_id, + stream_id, + resp, + } => { + let sender = dispatcher + .get(&session_id) + .and_then(|inner_map| inner_map.get(&stream_id).cloned()); + + tracing::debug!("get: {:?}", sender); + let _ = resp.send(sender); + } + SignalDispatchCommand::Delete { session_id } => { + dispatcher.remove(&session_id); + tracing::debug!("delete: {:?}", session_id); + } + } + } + + tracing::trace!("signal_dispatcher end"); +} + +#[derive(Clone)] +pub(crate) struct SignalDispatcher { + tx: mpsc::Sender, +} + +impl SignalDispatcher { + pub fn new(tx: mpsc::Sender) -> Self { + Self { tx } + } +} + +impl SignalDispatcher { + pub(crate) async fn transfer_signal_to_data_stream_thread( + &self, + session_id: usize, + stream_id: u64, + signal: Box, + ) -> Result<()> { + let (resp_tx, resp_rx) = oneshot::channel::>(); + + let cmd = SignalDispatchCommand::Get { + session_id, + stream_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let sender = resp_rx + .await? + .ok_or_else(|| anyhow::anyhow!("sender not found"))?; + let _ = sender.send(signal).await; + Ok(()) + } +} diff --git a/scripts/start-localhost-test-chrome.sh b/scripts/start-localhost-test-chrome.sh index e8c5d9c0..b479fcf4 100755 --- a/scripts/start-localhost-test-chrome.sh +++ b/scripts/start-localhost-test-chrome.sh @@ -5,6 +5,6 @@ # LICENSE file in the root directory of this source tree. # Get base 64 of scert cmd -certbase64=$(eval "openssl x509 -pubkey -noout -in ./moqt-server-sample/keys/cert.pem | openssl rsa -pubin -outform der | openssl dgst -sha256 -binary | base64") - -/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --test-type --origin-to-force-quic-on=localhost:4433 --ignore-certificate-errors-spki-list=$certbase64 --use-fake-device-for-media-stream \ No newline at end of file +certbase64=$(eval "openssl x509 -pubkey -noout -in ./moqt-server-sample/keys/cert.pem | openssl pkey -pubin -outform der | openssl dgst -sha256 -binary | openssl enc -base64") +/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --test-type --origin-to-force-quic-on=35.189.95.121:4433,localhost:4433 --ignore-certificate-errors-spki-list=$certbase64 +# /Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --test-type --origin-to-force-quic-on=35.189.95.121:4433,localhost:4433 --ignore-certificate-errors --ignore-certificate-errors-spki-list=$certbase64 --use-fake-device-for-media-stream \ No newline at end of file