diff --git a/Cargo.lock b/Cargo.lock index ea0fa279..8a7903c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,6 +96,18 @@ version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +[[package]] +name = "arrayref" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -117,6 +129,28 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "async-trait" version = "0.1.81" @@ -166,6 +200,40 @@ dependencies = [ "thiserror", ] +[[package]] +name = "atrium-streams" +version = "0.1.0" +dependencies = [ + "atrium-api", + "cbor4ii", + "futures", + "ipld-core", + "serde", + "serde_ipld_dagcbor", + "thiserror", +] + +[[package]] +name = "atrium-streams-client" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-stream", + "atrium-streams", + "atrium-xrpc", + "bon", + "futures", + "ipld-core", + "rs-car", + "serde", + "serde_html_form", + "serde_ipld_dagcbor", + "serde_json", + "thiserror", + "tokio", + "tokio-tungstenite", +] + [[package]] name = "atrium-xrpc" version = "0.11.3" @@ -254,6 +322,17 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -263,6 +342,29 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bon" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4f37d875011af3196e4828024742a84dcff6b0d027d272f2944f9a99f2c8af" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99b4b686e7ebf76cfa591052482d8c3c8242722518560798631974bf899d5565" +dependencies = [ + "darling", + "ident_case", + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "bsky-cli" version = "0.1.22" @@ -294,7 +396,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", - "toml", + "toml 0.8.15", "unicode-segmentation", ] @@ -304,6 +406,12 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +[[package]] +name = "byteorder" +version = "1.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.6.1" @@ -352,6 +460,19 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "cid" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd94671561e36e4e7de75f753f577edafb0e7c05d6e4547229fdf7938fbcd2c3" +dependencies = [ + "core2", + "multibase", + "multihash 0.18.1", + "serde", + "unsigned-varint 0.7.2", +] + [[package]] name = "cid" version = "0.11.1" @@ -360,7 +481,7 @@ checksum = "3147d8272e8fa0ccd29ce51194dd98f79ddfb8191ba9e3409884e751798acf3a" dependencies = [ "core2", "multibase", - "multihash", + "multihash 0.19.1", "serde", "serde_bytes", "unsigned-varint 0.8.0", @@ -385,7 +506,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim", + "strsim 0.10.0", ] [[package]] @@ -447,6 +568,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation" version = "0.9.4" @@ -540,6 +667,41 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.71", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.71", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -736,6 +898,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -758,6 +921,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.30" @@ -779,6 +953,17 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -797,11 +982,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum 
= "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1082,6 +1272,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1117,7 +1313,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ede82a79e134f179f4b29b5fdb1eb92bd1b38c4dfea394c539051150a21b9b" dependencies = [ - "cid", + "cid 0.11.1", "serde", "serde_bytes", ] @@ -1209,6 +1405,55 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libipld" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1ccd6b8ffb3afee7081fcaec00e1b099fd1c7ccf35ba5729d88538fcc3b4599" +dependencies = [ + "fnv", + "libipld-cbor", + "libipld-core", + "libipld-macro", + "log", + "multihash 0.18.1", + "thiserror", +] + +[[package]] +name = "libipld-cbor" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77d98c9d1747aa5eef1cf099cd648c3fd2d235249f5fed07522aaebc348e423b" +dependencies = [ + "byteorder", + "libipld-core", + "thiserror", +] + +[[package]] +name = "libipld-core" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5acd707e8d8b092e967b2af978ed84709eaded82b75effe6cb6f6cc797ef8158" +dependencies = [ + "anyhow", + "cid 0.10.1", + "core2", + "multibase", + "multihash 0.18.1", + "thiserror", +] + +[[package]] +name = "libipld-macro" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71171c54214f866ae6722f3027f81dff0931e600e5a61e6b1b6a49ca0b5ed4ae" +dependencies = [ + "libipld-core", +] + [[package]] name = "libnghttp2-sys" version = "0.1.10+1.61.0" @@ -1325,6 +1570,17 @@ dependencies = [ "data-encoding-macro", ] +[[package]] +name = "multihash" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfd8a792c1694c6da4f68db0a9d707c72bd260994da179e6030a5dcee00bb815" +dependencies = [ + "core2", + "multihash-derive", + "unsigned-varint 0.7.2", +] + [[package]] name = "multihash" version = "0.19.1" @@ -1336,6 +1592,20 @@ dependencies = [ "unsigned-varint 0.7.2", ] +[[package]] +name = "multihash-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6d4752e6230d8ef7adf7bd5d8c4b1f6561c1014c5ba9a37445ccefe18aa1db" +dependencies = [ + "proc-macro-crate", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", + "synstructure", +] + [[package]] name = "native-tls" version = "0.2.12" @@ -1563,6 +1833,40 @@ dependencies = [ "elliptic-curve", ] +[[package]] +name = "proc-macro-crate" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" +dependencies = [ + "thiserror", + "toml 0.5.11", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -1792,6 +2096,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rs-car" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf69c4017006c0101362b5df74ee230331703e9938f970468dc1e429afe12998" +dependencies = [ + "blake2b_simd", + "futures", + "libipld", + "sha2", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1864,6 +2180,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.18" @@ -2014,6 +2336,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -2108,6 +2441,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -2142,6 +2481,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "unicode-xid", +] + [[package]] name = "tempfile" version = "3.10.1" @@ -2240,6 +2591,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -2253,6 +2618,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + [[package]] name = "toml" version = "0.8.15" @@ -2362,6 +2736,26 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "native-tls", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -2395,6 +2789,12 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +[[package]] +name = "unicode-xid" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a" + [[package]] name = "unsigned-varint" version = "0.7.2" @@ -2424,6 +2824,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index ffacaf5e..6dfe6d71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,8 @@ members = [ "atrium-crypto", "atrium-xrpc", "atrium-xrpc-client", + "atrium-streams", + "atrium-streams-client", "bsky-cli", "bsky-sdk", ] @@ -26,6 +28,8 @@ keywords = ["atproto", "bluesky"] atrium-api = { version = "0.24.4", path = "atrium-api" } atrium-xrpc = { version = "0.11.3", path = "atrium-xrpc" } atrium-xrpc-client = { version = "0.5.6", path = "atrium-xrpc-client" } +atrium-streams = { version = "0.1.0", path = "atrium-streams" } +atrium-streams-client = { version = "0.1.0", path = "atrium-streams-client" } bsky-sdk = { version = "0.1.9", path = "bsky-sdk" } # async in traits @@ -35,6 +39,10 @@ async-trait = "0.1.80" # DAG-CBOR codec ipld-core = { version = "0.4.1", default-features = false, features = ["std"] } serde_ipld_dagcbor = { version = "0.6.0", default-features = false, features = ["std"] } +cbor4ii = { version = "0.2.14", default-features = false } + +# CAR files +rs-car = "0.4.1" # Parsing and validation chrono = "0.4" @@ -55,8 +63,10 @@ rand = "0.8.5" # Networking futures = { version = "0.3.30", default-features = false, features = ["alloc"] } +async-stream = "0.3.5" http = "1.1.0" tokio = { version = "1.37", default-features = false } +tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] } # HTTP client integrations isahc = "1.7.2" @@ -76,3 +86,6 @@ mockito = "1.4" # WebAssembly wasm-bindgen-test = "0.3.41" bumpalo = "~3.14.0" + +# Code generation +bon = "2.2.1" \ No newline at end of file diff --git a/README.md b/README.md index 0822a1f0..4f2d7da6 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,14 @@ Definitions for XRPC request/response, and their associated errors. A library provides clients that implement the `XrpcClient` defined in [atrium-xrpc](./atrium-xrpc/) +### [`atrium-streams`](./atrium-streams/) + +Definitions for traits, types and utilities for dealing with event stream subscriptions. 
(WIP) + +### [`atrium-streams-client`](./atrium-streams-client/) + +A library that provides default implementations of the `EventStreamClient`, `Handlers` and `Subscription` defined in [atrium-streams](./atrium-streams/) for interacting with the variety of subscriptions in ATProto (WIP) + ### [`bsky-sdk`](./bsky-sdk/) [![](https://img.shields.io/crates/v/bsky-sdk)](https://crates.io/crates/bsky-sdk) diff --git a/atrium-streams-client/.gitignore b/atrium-streams-client/.gitignore new file mode 100644 index 00000000..4fffb2f8 --- /dev/null +++ b/atrium-streams-client/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/atrium-streams-client/CHANGELOG.md b/atrium-streams-client/CHANGELOG.md new file mode 100644 index 00000000..df3cff36 --- /dev/null +++ b/atrium-streams-client/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). \ No newline at end of file diff --git a/atrium-streams-client/Cargo.toml b/atrium-streams-client/Cargo.toml new file mode 100644 index 00000000..2c399bda --- /dev/null +++ b/atrium-streams-client/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "atrium-streams-client" +version = "0.1.0" +authors = ["Elaina <17bestradiol@proton.me>"] +edition.workspace = true +rust-version.workspace = true +description = "Event Streams Client library for AT Protocol (Bluesky)" +documentation = "https://docs.rs/atrium-streams-client" +readme = "README.md" +repository.workspace = true +license.workspace = true +keywords.workspace = true + +[dependencies] +atrium-xrpc.workspace = true +atrium-streams.workspace = true +futures.workspace = true +ipld-core.workspace = true +async-stream.workspace = true +tokio-tungstenite.workspace = true +serde_ipld_dagcbor.workspace = true +rs-car.workspace = true +tokio.workspace = true +bon.workspace = true +serde_html_form.workspace = true +serde.workspace = true +thiserror.workspace = true + +[dev-dependencies] +anyhow.workspace = true +serde_json.workspace = true +tokio = { version = "1.37", default-features = false, features = ["rt-multi-thread"] } \ No newline at end of file diff --git a/atrium-streams-client/README.md b/atrium-streams-client/README.md new file mode 100644 index 00000000..5e919be9 --- /dev/null +++ b/atrium-streams-client/README.md @@ -0,0 +1 @@ +# ATrium XRPC WSS Client \ No newline at end of file diff --git a/atrium-streams-client/src/client/mod.rs b/atrium-streams-client/src/client/mod.rs new file mode 100644 index 00000000..29e694e7 --- /dev/null +++ b/atrium-streams-client/src/client/mod.rs @@ -0,0 +1,101 @@ +//! This file provides a client for the `ATProto` XRPC over WSS protocol. +//! It implements the [`EventStreamClient`] trait for the [`WssClient`] struct. + +#[cfg(test)] +mod tests; + +use std::str::FromStr; + +use futures::Stream; +use tokio::net::TcpStream; + +use atrium_xrpc::{ + http::{Request, Uri}, + types::Header, +}; +use bon::Builder; +use serde::Serialize; +use tokio_tungstenite::{ + connect_async, + tungstenite::{self, handshake::client::generate_key}, + MaybeTlsStream, WebSocketStream, +}; + +use atrium_streams::client::EventStreamClient; + +/// An enum of possible error kinds for this crate. 
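+/// Returned by [`WssClient::connect`] when the URI is invalid, the query parameters fail to serialize, or the WebSocket handshake fails.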
+#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Invalid uri")] + InvalidUri, + #[error("Parsing parameters failed: {0}")] + ParsingParameters(#[from] serde_html_form::ser::Error), + #[error("Connection error: {0}")] + Connection(#[from] tungstenite::Error), +} + +#[derive(Builder)] +pub struct WssClient<P: Serialize + Send + Sync> { + params: Option<P>, +} +
+type StreamKind = WebSocketStream<MaybeTlsStream<TcpStream>>; +impl<P: Serialize + Send + Sync> EventStreamClient<<StreamKind as Stream>::Item, Error> + for WssClient<P>
+{ + async fn connect( + &self, + mut uri: String, + ) -> Result<impl Stream<Item = <StreamKind as Stream>::Item>, Error> { + let Self { params } = self; + + // Query parameters + if let Some(p) = &params { + uri.push('?'); + uri += &serde_html_form::to_string(p)?; + }; + + // Request + let (uri, host) = get_host(&uri)?; + let request = gen_request(self, &uri, &*host).await?; + + // Connection + let (stream, _) = connect_async(request).await?; + Ok(stream) + } +} +
+/// Extract the URI and host from a string. +fn get_host(uri: &str) -> Result<(Uri, Box<str>), Error> { + let uri = Uri::from_str(uri).map_err(|_| Error::InvalidUri)?; + let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str(); + let host = authority.find('@').map_or_else(|| authority, |idx| authority.split_at(idx + 1).1); + let host = Box::from(host); + Ok((uri, host)) +} +
+/// Generate a request for the given URI and host. +/// It sets the necessary headers for a WebSocket connection, +/// plus the client's `AtprotoProxy` and `AtprotoAcceptLabelers` headers. +async fn gen_request<P: Serialize + Send + Sync>( + client: &WssClient<P>
, + uri: &Uri, + host: &str, +) -> Result, Error> { + let mut request = Request::builder() + .uri(uri) + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()); + if let Some(proxy) = client.atproto_proxy_header().await { + request = request.header(Header::AtprotoProxy, proxy); + } + if let Some(accept_labelers) = client.atproto_accept_labelers_header().await { + request = request.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); + } + let request = request.body(()).map_err(|_| Error::InvalidUri)?; + Ok(request) +} diff --git a/atrium-streams-client/src/client/tests.rs b/atrium-streams-client/src/client/tests.rs new file mode 100644 index 00000000..d7265ab3 --- /dev/null +++ b/atrium-streams-client/src/client/tests.rs @@ -0,0 +1,94 @@ +use std::net::{Ipv4Addr, SocketAddr}; + +use atrium_streams::{atrium_api::com::atproto::sync::subscribe_repos, client::EventStreamClient}; +use atrium_xrpc::http::{header::SEC_WEBSOCKET_KEY, HeaderMap, HeaderValue}; +use futures::{SinkExt, StreamExt}; +use tokio::{ + net::{TcpListener, TcpStream}, + runtime::Runtime, +}; +use tokio_tungstenite::{ + tungstenite::{ + handshake::server::{ErrorResponse, Request, Response}, + Message, + }, + WebSocketStream, +}; + +use crate::WssClient; + +use super::{gen_request, get_host}; + +#[test] +fn client() { + let fut = async { + let ipv4 = Ipv4Addr::LOCALHOST.to_string(); + let xrpc_uri = format!("ws://{ipv4}:3000/xrpc/{}", subscribe_repos::NSID); + let (client, mut client_headers) = wss_client(&xrpc_uri).await; + + let server_handle = tokio::spawn(mock_wss_server()); + let mut client_stream = client.connect(xrpc_uri).await.unwrap(); + let (server_stream, mut server_headers, route) = server_handle.await.unwrap(); + + assert_eq!(route, format!("/xrpc/{}", subscribe_repos::NSID)); + + client_headers.remove(SEC_WEBSOCKET_KEY); + server_headers.remove(SEC_WEBSOCKET_KEY); + assert_eq!(client_headers, server_headers); + + let (mut inbound, _) = server_stream.split(); + inbound.send(Message::text("test_message")).await.unwrap(); + let msg = client_stream.next().await.unwrap().unwrap(); + assert_eq!(msg, Message::text("test_message")); + }; + Runtime::new().unwrap().block_on(fut); +} + +async fn wss_client( + uri: &str, +) -> (WssClient, HeaderMap) { + let params = subscribe_repos::ParametersData { cursor: None }; + + let client = WssClient::builder().params(params).build(); + + let (uri, host) = get_host(uri).unwrap(); + let req = gen_request(&client, &uri, &host).await.unwrap(); + let headers = req.headers(); + + (client, headers.clone()) +} + +async fn mock_wss_server() -> (WebSocketStream, HeaderMap, String) { + let sock_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 3000)); + + let listener = TcpListener::bind(sock_addr).await.expect("Failed to bind to port!"); + + let headers: HeaderMap; + let route: String; + let (stream, _) = listener.accept().await.unwrap(); + let (headers_, route_, stream) = extract_headers(stream).await; + headers = headers_; + route = route_; + + (stream, headers, route) +} + +async fn extract_headers( + raw_stream: TcpStream, +) -> (HeaderMap, String, WebSocketStream) { + let mut headers: Option> = None; + let mut route: Option = None; + + let copy_headers_callback = + |request: &Request, response: Response| -> Result { + headers = Some(request.headers().clone()); + route = Some(request.uri().path().to_owned()); + Ok(response) + }; + + let 
stream = tokio_tungstenite::accept_hdr_async(raw_stream, copy_headers_callback) + .await + .expect("Error during the websocket handshake occurred"); + + (headers.unwrap(), route.unwrap(), stream) +} diff --git a/atrium-streams-client/src/lib.rs b/atrium-streams-client/src/lib.rs new file mode 100644 index 00000000..f0ee3e7b --- /dev/null +++ b/atrium-streams-client/src/lib.rs @@ -0,0 +1,6 @@ +mod client; +pub use client::{Error, WssClient}; + +pub mod subscriptions; + +pub use atrium_streams; // Re-export the atrium_streams crate diff --git a/atrium-streams-client/src/subscriptions/mod.rs b/atrium-streams-client/src/subscriptions/mod.rs new file mode 100644 index 00000000..21b552a0 --- /dev/null +++ b/atrium-streams-client/src/subscriptions/mod.rs @@ -0,0 +1 @@ +pub mod repositories; diff --git a/atrium-streams-client/src/subscriptions/repositories/firehose.rs b/atrium-streams-client/src/subscriptions/repositories/firehose.rs new file mode 100644 index 00000000..e07bde2f --- /dev/null +++ b/atrium-streams-client/src/subscriptions/repositories/firehose.rs @@ -0,0 +1,271 @@ +use std::{collections::BTreeMap, io::Cursor}; + +use futures::io::Cursor as FutCursor; +use ipld_core::cid::Cid; + +use super::type_defs::{self, Operation}; +use atrium_streams::{ + atrium_api::{ + com::atproto::sync::subscribe_repos::{ + self, AccountData, CommitData, HandleData, IdentityData, InfoData, MigrateData, + RepoOpData, TombstoneData, + }, + record::KnownRecord, + types::Object, + }, + subscriptions::{ + handlers::repositories::{HandledData, Handler, ProcessedData}, + ConnectionHandler, ProcessedPayload, + }, +}; + +/// Errors for this crate +#[derive(Debug, thiserror::Error)] +pub enum HandlingError { + #[error("CAR Decoding error: {0}")] + CarDecoding(#[from] rs_car::CarDecodeError), + #[error("IPLD Decoding error: {0}")] + IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), +} + +#[derive(bon::Builder)] +pub struct Firehose { + #[builder(default)] + enable_commit: bool, + #[builder(default)] + enable_identity: bool, + #[builder(default)] + enable_account: bool, + #[builder(default)] + enable_handle: bool, + #[builder(default)] + enable_migrate: bool, + #[builder(default)] + enable_tombstone: bool, + #[builder(default)] + enable_info: bool, +} +impl ConnectionHandler for Firehose { + type HandledData = HandledData; + type HandlingError = self::HandlingError; + + async fn handle_payload( + &self, + t: String, + payload: Vec, + ) -> Result>, Self::HandlingError> { + let res = match t.as_str() { + "#commit" => { + if self.enable_commit { + self.process_commit(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Commit)) + } else { + None + } + } + "#identity" => { + if self.enable_identity { + self.process_identity(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Identity)) + } else { + None + } + } + "#account" => { + if self.enable_account { + self.process_account(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Account)) + } else { + None + } + } + "#handle" => { + if self.enable_handle { + self.process_handle(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Handle)) + } else { + None + } + } + "#migrate" => { + if self.enable_migrate { + self.process_migrate(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? 
+ .map(|data| data.map(ProcessedData::Migrate)) + } else { + None + } + } + "#tombstone" => { + if self.enable_tombstone { + self.process_tombstone(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Tombstone)) + } else { + None + } + } + "#info" => { + if self.enable_info { + self.process_info(serde_ipld_dagcbor::from_reader(payload.as_slice())?) + .await? + .map(|data| data.map(ProcessedData::Info)) + } else { + None + } + } + _ => { + // "Clients should ignore frames with headers that have unknown op or t values. + // Unknown fields in both headers and payloads should be ignored." + // https://atproto.com/specs/event-stream + None + } + }; + + Ok(res) + } +} + +impl Handler for Firehose { + type ProcessedCommitData = type_defs::ProcessedCommitData; + async fn process_commit( + &self, + payload: subscribe_repos::Commit, + ) -> Result>, Self::HandlingError> { + let CommitData { blobs, blocks, commit, ops, repo, rev, seq, since, time, too_big, .. } = + payload.data; + + // If it is too big, the blocks and ops are not sent, so we skip the processing. + let ops_opt = if too_big { + None + } else { + // We read all the blocks from the CAR file and store them in a map + // so that we can look up the data for each operation by its CID. + let mut cursor = FutCursor::new(blocks); + let mut map = rs_car::car_read_all(&mut cursor, true) + .await? + .0 + .into_iter() + .map(compat_cid) + .collect::>(); + + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + Some(process_ops(ops, &mut map)?) + }; + + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedCommitData { ops: ops_opt, blobs, commit, repo, rev, since, time }, + })) + } + + type ProcessedIdentityData = type_defs::ProcessedIdentityData; + async fn process_identity( + &self, + payload: subscribe_repos::Identity, + ) -> Result>, Self::HandlingError> { + let IdentityData { did, handle, seq, time } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedIdentityData { did, handle, time }, + })) + } + + type ProcessedAccountData = type_defs::ProcessedAccountData; + async fn process_account( + &self, + payload: subscribe_repos::Account, + ) -> Result>, Self::HandlingError> { + let AccountData { did, seq, time, active, status } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedAccountData { did, active, status, time }, + })) + } + + type ProcessedHandleData = type_defs::ProcessedHandleData; + async fn process_handle( + &self, + payload: subscribe_repos::Handle, + ) -> Result>, Self::HandlingError> { + let HandleData { did, handle, seq, time } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedHandleData { did, handle, time }, + })) + } + + type ProcessedMigrateData = type_defs::ProcessedMigrateData; + async fn process_migrate( + &self, + payload: subscribe_repos::Migrate, + ) -> Result>, Self::HandlingError> { + let MigrateData { did, migrate_to, seq, time } = payload.data; + Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedMigrateData { did, migrate_to, time }, + })) + } + + type ProcessedTombstoneData = type_defs::ProcessedTombstoneData; + async fn process_tombstone( + &self, + payload: subscribe_repos::Tombstone, + ) -> Result>, Self::HandlingError> { + let TombstoneData { did, seq, time } = payload.data; + 
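// Tombstone events carry only the repo DID and the deletion time; `seq` is passed through as the cursor. +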
Ok(Some(ProcessedPayload { + seq: Some(seq), + data: Self::ProcessedTombstoneData { did, time }, + })) + } + + type ProcessedInfoData = InfoData; + async fn process_info( + &self, + payload: subscribe_repos::Info, + ) -> Result>, Self::HandlingError> { + Ok(Some(ProcessedPayload { seq: None, data: payload.data })) + } +} + +// Transmute is here because the version of the `rs_car` crate for `cid` is 0.10.1 whereas +// the `ilpd_core` crate is 0.11.1. Should work regardless, given that the Cid type's +// memory layout was not changed between the two versions. Temporary fix. +// TODO: Find a better way to fix the version compatibility issue. +fn compat_cid((cid, item): (rs_car::Cid, Vec)) -> (ipld_core::cid::Cid, Vec) { + (unsafe { std::mem::transmute::<_, Cid>(cid) }, item) +} + +fn process_ops( + ops: Vec>, + map: &mut BTreeMap>, +) -> Result, serde_ipld_dagcbor::DecodeError> { + let mut processed_ops = Vec::with_capacity(ops.len()); + for op in ops { + processed_ops.push(process_op(map, op)?); + } + Ok(processed_ops) +} + +/// Processes a single operation. +fn process_op( + map: &mut BTreeMap>, + op: Object, +) -> Result> { + let RepoOpData { action, path, cid } = op.data; + + // Finds in the map the `Record` with the operation's CID and deserializes it. + // If the item is not found, returns `None`. + let record = match cid.as_ref().and_then(|c| map.get_mut(&c.0)) { + Some(item) => Some(serde_ipld_dagcbor::from_reader::(Cursor::new(item))?), + None => None, + }; + + Ok(Operation { action, path, record }) +} diff --git a/atrium-streams-client/src/subscriptions/repositories/mod.rs b/atrium-streams-client/src/subscriptions/repositories/mod.rs new file mode 100644 index 00000000..48e9f2ad --- /dev/null +++ b/atrium-streams-client/src/subscriptions/repositories/mod.rs @@ -0,0 +1,115 @@ +pub mod firehose; +pub mod type_defs; + +#[cfg(test)] +mod tests; + +use std::marker::PhantomData; + +use async_stream::stream; +use bon::bon; +use futures::{Stream, StreamExt}; +use tokio_tungstenite::tungstenite::Message; + +use atrium_streams::{ + atrium_api::com::atproto::sync::subscribe_repos, + subscriptions::{ + frames::{self, Frame}, + ConnectionHandler, ProcessedPayload, Subscription, SubscriptionError, + }, +}; + +/// A struct that represents the repositories subscription, used in `com.atproto.sync.subscribeRepos`. +pub struct Repositories { + /// This is only here to constrain the `ConnectionPayload` used in [`Subscription`], or else we get a compile error. + _payload_kind: PhantomData, +} +#[bon] +impl Repositories +where + Self: Subscription, +{ + #[builder] + /// Defines the builder for any generic `Repositories` struct that implements [`Subscription`]. + pub fn new( + connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream< + Item = Result, SubscriptionError>, + > { + Self::handle_connection(connection, handler) + } +} + +type WssResult = tokio_tungstenite::tungstenite::Result; +impl Subscription for Repositories { + fn handle_connection( + mut connection: impl Stream + Unpin, + handler: H, + ) -> impl Stream< + Item = Result, SubscriptionError>, + > { + // Builds a new async stream that will deserialize the packets sent through the + // TCP tunnel and then yield the results processed by the handler back to the caller. + let stream = stream! 
{ + loop { + match connection.next().await { + None => break, // Server dropped connection + Some(Err(e)) => { // WebSocket error + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid packet. Error: {e:?}"))); + break; + } + Some(Ok(Message::Binary(data))) => { + match Frame::try_from(data) { + Ok(Frame::Message { t, data: payload }) => { + match handler.handle_payload(t, payload).await { + Ok(Some(res)) => yield Ok(res), // Payload was successfully handled. + Ok(None) => {}, // Payload was ignored by Handler. + Err(e) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid payload. Error: {e:?}"))); + break; + }, + } + }, + Ok(Frame::Error { data }) => { + yield match serde_ipld_dagcbor::from_reader::(data.as_slice()) { + Ok(e) => Err(SubscriptionError::Other(e)), + Err(e) => Err(SubscriptionError::Unknown(format!("Failed to decode error frame: {e:?}"))), + }; + break; + }, + Err(frames::Error::EmptyPayload(ipld)) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard frames::errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received empty payload for header: {ipld:?}"))); + break; + }, + Err(frames::Error::IpldDecoding(e)) => { + // "Invalid framing or invalid DAG-CBOR encoding are hard errors, + // and the client should drop the entire connection instead of skipping the frame." + // https://atproto.com/specs/event-stream + yield Err(SubscriptionError::Abort(format!("Received invalid frame. Error: {e:?}"))); + break; + }, + Err(frames::Error::UnknownFrameType(_)) => { + // "Clients should ignore frames with headers that have unknown op or t values. + // Unknown fields in both headers and payloads should be ignored." + // https://atproto.com/specs/event-stream + }, + } + } + _ => {}, // Ignore other message types. 
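+ // Text, Ping, Pong and Close frames carry no event-stream payload, so there is nothing to yield here.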
+ } + } + }; + + Box::pin(stream) + } +} diff --git a/atrium-streams-client/src/subscriptions/repositories/tests.rs b/atrium-streams-client/src/subscriptions/repositories/tests.rs new file mode 100644 index 00000000..d94dd78a --- /dev/null +++ b/atrium-streams-client/src/subscriptions/repositories/tests.rs @@ -0,0 +1,461 @@ +use std::{convert::identity as id, vec}; + +use atrium_streams::{ + atrium_api::{ + com::atproto::{ + label::subscribe_labels::InfoData, + sync::subscribe_repos::{ + self, AccountData, CommitData, HandleData, IdentityData, MigrateData, TombstoneData, + }, + }, + types::{ + string::{Datetime, Did, Handle}, + CidLink, Object, + }, + }, + subscriptions::{ + handlers::repositories::HandledData, ConnectionHandler, ProcessedPayload, SubscriptionError, + }, +}; +use futures::{executor::block_on_stream, Stream}; +use ipld_core::{ + cid::{multihash::Multihash, Cid}, + ipld::Ipld, +}; +use serde_json::Value; +use tokio_tungstenite::tungstenite::{Error, Message}; + +use super::{firehose::Firehose, Repositories}; + +fn serialize_ipld(frame: &str) -> Result, anyhow::Error> { + if frame.is_empty() { + return Ok(vec![]); + } + + let json: Value = serde_json::from_str(frame)?; + let bytes = serde_ipld_dagcbor::to_vec(&json)?; + Ok(bytes) +} + +fn mock_connection<'a>( + packets: Vec<(&'a str, &'a str)>, +) -> impl Stream> + Unpin + 'a { + let mut stream = packets.into_iter().map(|(header, payload)| { + // Using Utf8 as an arbitrary tungstenite error + serialize_ipld(header) + .map(|mut v| { + serialize_ipld(payload) + .map(|mut p| { + Message::Binary({ + v.append(&mut p); + v + }) + }) + .map_err(|_| Error::Utf8) + }) + .map_err(|_| Error::Utf8) + .and_then(id) + }); + let connection = async_stream::stream! { + while let Some(packet) = stream.next() { + yield packet; + } + }; + + Box::pin(connection) +} + +fn test_packet( + packet: Option<(&str, &str)>, +) -> Option, HandledData), SubscriptionError>> +{ + let connection = mock_connection(if let Some(packet) = packet { vec![packet] } else { vec![] }); + + let subscription = gen_default_subscription(connection); + + block_on_stream(subscription) + .next() + .map(|v| v.map(|ProcessedPayload { data, seq }| (seq, data))) +} + +fn gen_default_subscription( + connection: impl Stream> + Unpin, +) -> impl Stream< + Item = Result< + ProcessedPayload<::HandledData>, + SubscriptionError, + >, +> { + let firehose = Firehose::builder() + .enable_commit(true) + .enable_identity(true) + .enable_account(true) + .enable_handle(true) + .enable_migrate(true) + .enable_tombstone(true) + .enable_info(true) + .build(); + let subscription = Repositories::builder().connection(connection).handler(firehose).build(); + subscription +} + +#[test] +fn disconnect() { + if test_packet(None).is_none() { + return; + } + panic!("Expected None") +} + +#[test] +fn invalid_packet() { + if let SubscriptionError::Abort(msg) = + test_packet(Some(("{ not-a-header }", "{ not-a-payload }"))).unwrap().unwrap_err() + { + assert_eq!(msg, "Received invalid packet. 
Error: Utf8"); + return; + } + panic!("Expected Invalid Packet") +} + +#[test] +fn commit() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: Some(CommitData { + blobs: vec![], + blocks: vec![], + commit: CidLink(Cid::new_v1(0x70, Multihash::<64>::wrap(0x12, &[0; 64]).unwrap())), + ops: vec![], + prev: None, + rebase: false, + repo: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + rev: String::new(), + seq: 99, + since: None, + time: now, + too_big: true, + }), + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Commit(ProcessedCommitData {{ \ + repo: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + commit: CidLink(Cid(bafybeqaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa)), \ + ops: None, \ + blobs: [], \ + rev: \"\", \ + since: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn identity() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: IdentityData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + handle: None, + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#identity" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Identity(ProcessedIdentityData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + handle: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn account() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: AccountData { + active: false, + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + seq: 99, + status: None, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#account" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Account(ProcessedAccountData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + active: false, \ + status: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn handle() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: HandleData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + handle: Handle::new("test.handle.xyz".to_string()).unwrap(), + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#handle" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Handle(ProcessedHandleData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + handle: Handle(\"test.handle.xyz\"), \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn migrate() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: MigrateData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + migrate_to: None, + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = 
serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#migrate" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Migrate(ProcessedMigrateData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + migrate_to: None, \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn tombstone() { + let now = Datetime::now(); + let now_str = format!("{:?}", now); + let body = Object { + data: TombstoneData { + did: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + seq: 99, + time: now, + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#tombstone" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, Some(99)); + assert_eq!( + format!("{:?}", data), + format!( + "Tombstone(ProcessedTombstoneData {{ \ + did: Did(\"did:plc:ewvi7nxzyoun6zhxrhs64oiz\"), \ + time: {now_str} \ + }})" + ) + ); +} + +#[test] +fn info() { + let body = Object { + data: InfoData { + message: Some("Requested cursor exceeded limit. Possibly missing events".to_string()), + name: "OutdatedCursor".to_string(), + }, + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + let (seq, data) = + test_packet(Some((r##"{ "op": 1, "t": "#info" }"##, &body))).unwrap().unwrap(); + assert_eq!(seq, None); + assert_eq!( + format!("{:?}", data), + "Info(InfoData { \ + message: Some(\"Requested cursor exceeded limit. Possibly missing events\"), \ + name: \"OutdatedCursor\" \ + })" + .to_string() + ); +} + +#[test] +fn ignored_frame() { + if test_packet(Some((r##"{ "op": 1, "t": "#non-existent" }"##, r#"{ "foo": "bar" }"#))) + .is_none() + { + return; + } + panic!("Expected None") +} + +#[test] +fn invalid_body() { + let body = Object { + data: Some(CommitData { + blobs: vec![], + blocks: vec![1], // Invalid CAR file + commit: CidLink(Cid::new_v1(0x70, Multihash::<64>::wrap(0x12, &[0; 64]).unwrap())), + ops: vec![], + prev: None, + rebase: false, + repo: Did::new("did:plc:ewvi7nxzyoun6zhxrhs64oiz".to_string()).unwrap(), + rev: String::new(), + seq: 0, + since: None, + time: Datetime::now(), + too_big: false, + }), + extra_data: Ipld::Null, + }; + let body = serde_json::to_string(&body).unwrap(); + if let SubscriptionError::Abort(msg) = + test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, &body))).unwrap().unwrap_err() + { + assert_eq!( + msg, + "Received invalid payload. Error: CarDecoding(IoError(Kind(UnexpectedEof)))" + ); + return; + } +} + +#[test] +fn future_cursor() { + let res = test_packet(Some(( + r##"{ "op": -1 }"##, + r#"{ "error": "FutureCursor", "message": "Cursor in the future." 
}"#, + ))); + if let SubscriptionError::Other(subscribe_repos::Error::FutureCursor(Some(s))) = + res.unwrap().unwrap_err() + { + assert_eq!("Cursor in the future.", &s); + return; + } + panic!("Expected FutureCursor") +} + +#[test] +fn consumer_too_slow() { + let res = test_packet(Some(( + r##"{ "op": -1 }"##, + r#"{ "error": "ConsumerTooSlow", "message": "Stream consumer too slow" }"#, + ))); + if let SubscriptionError::Other(subscribe_repos::Error::ConsumerTooSlow(Some(s))) = + res.unwrap().unwrap_err() + { + assert_eq!("Stream consumer too slow", &s); + return; + } + panic!("Expected ConsumerTooSlow") +} + +#[test] +fn unknown_error() { + let res = test_packet(Some(( + r##"{ "op": -1 }"##, + r#"{ "error": "Unknown", "message": "No one knows" }"#, + ))); + if let SubscriptionError::Unknown(msg) = res.unwrap().unwrap_err() { + assert_eq!( + "Failed to decode error frame: \ + Msg(\"unknown variant `Unknown`, expected `FutureCursor` or `ConsumerTooSlow`\")", + &msg + ); + return; + } + panic!("Expected Unknown") +} + +#[test] +fn empty_payload() { + let res = test_packet(Some((r##"{ "op": 1, "t": "#commit" }"##, r#""#))); + if let SubscriptionError::Abort(msg) = res.unwrap().unwrap_err() { + assert_eq!("Received empty payload for header: {\"op\": 1, \"t\": \"#commit\"}", &msg); + return; + } + panic!("Expected Empty Payload") +} + +#[test] +fn invalid_frame() { + fn mock_invalid() -> impl Stream> + Unpin { + let mut stream = vec![Message::Binary(vec![b'{'])].into_iter(); + let connection = async_stream::stream! { + while let Some(packet) = stream.next() { + yield Ok(packet); + } + }; + Box::pin(connection) + } + + let subscription = gen_default_subscription(mock_invalid()); + + let res = block_on_stream(subscription) + .next() + .map(|v| v.map(|ProcessedPayload { data, seq }| (seq, data))); + + if let SubscriptionError::Abort(msg) = res.unwrap().unwrap_err() { + assert_eq!("Received invalid frame. Error: Eof", &msg); + return; + } + panic!("Expected Invalid Frame") +} + +#[test] +fn unknown_frame() { + let res = test_packet(Some((r##"{ "op": 2 }"##, r#"{ "unknown": "header" }"#))); + if res.is_none() { + return; + } + panic!("Expected None") +} diff --git a/atrium-streams-client/src/subscriptions/repositories/type_defs.rs b/atrium-streams-client/src/subscriptions/repositories/type_defs.rs new file mode 100644 index 00000000..ce9b272c --- /dev/null +++ b/atrium-streams-client/src/subscriptions/repositories/type_defs.rs @@ -0,0 +1,74 @@ +//! This file defines the types used in the Firehose handler. + +use atrium_streams::atrium_api::{ + record::KnownRecord, + types::{ + string::{Datetime, Did, Handle}, + CidLink, + }, +}; + +// region: Commit +#[derive(Debug)] +pub struct ProcessedCommitData { + pub repo: Did, + pub commit: CidLink, + // `ops` can be `None` if the commit is marked as `too_big`. 
+ pub ops: Option>, + pub blobs: Vec, + pub rev: String, + pub since: Option, + pub time: Datetime, +} +#[derive(Debug)] +pub struct Operation { + pub action: String, + pub path: String, + pub record: Option, +} +// endregion: Commit + +// region: Identity +#[derive(Debug)] +pub struct ProcessedIdentityData { + pub did: Did, + pub handle: Option, + pub time: Datetime, +} +// endregion: Identity + +// region: Account +#[derive(Debug)] +pub struct ProcessedAccountData { + pub did: Did, + pub active: bool, + pub status: Option, + pub time: Datetime, +} +// endregion: Account + +// region: Handle +#[derive(Debug)] +pub struct ProcessedHandleData { + pub did: Did, + pub handle: Handle, + pub time: Datetime, +} +// endregion: Handle + +// region: Migrate +#[derive(Debug)] +pub struct ProcessedMigrateData { + pub did: Did, + pub migrate_to: Option, + pub time: Datetime, +} +// endregion: Migrate + +// region: Tombstone +#[derive(Debug)] +pub struct ProcessedTombstoneData { + pub did: Did, + pub time: Datetime, +} +// endregion: Tombstone diff --git a/atrium-streams/.gitignore b/atrium-streams/.gitignore new file mode 100644 index 00000000..4fffb2f8 --- /dev/null +++ b/atrium-streams/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/atrium-streams/CHANGELOG.md b/atrium-streams/CHANGELOG.md new file mode 100644 index 00000000..df3cff36 --- /dev/null +++ b/atrium-streams/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). \ No newline at end of file diff --git a/atrium-streams/Cargo.toml b/atrium-streams/Cargo.toml new file mode 100644 index 00000000..d1f7453d --- /dev/null +++ b/atrium-streams/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "atrium-streams" +version = "0.1.0" +authors = ["Elaina <17bestradiol@proton.me>"] +edition.workspace = true +rust-version.workspace = true +description = "Event Streams library for AT Protocol (Bluesky)" +documentation = "https://docs.rs/atrium-streams" +readme = "README.md" +repository.workspace = true +license.workspace = true +keywords.workspace = true + +[dependencies] +atrium-api.workspace = true +futures.workspace = true +cbor4ii.workspace = true +ipld-core.workspace = true +serde.workspace = true +serde_ipld_dagcbor.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/atrium-streams/README.md b/atrium-streams/README.md new file mode 100644 index 00000000..0ada3f05 --- /dev/null +++ b/atrium-streams/README.md @@ -0,0 +1 @@ +# ATrium XRPC WSS \ No newline at end of file diff --git a/atrium-streams/src/client.rs b/atrium-streams/src/client.rs new file mode 100644 index 00000000..5226c194 --- /dev/null +++ b/atrium-streams/src/client.rs @@ -0,0 +1,25 @@ +use std::future::Future; + +use futures::Stream; + +/// An abstract WSS client. +pub trait EventStreamClient { + /// Send an XRPC request. + /// + /// # Returns + /// [`Result`] + fn connect( + &self, + uri: String, + ) -> impl Future, ConnectionError>> + Send; + + /// Get the `atproto-proxy` header. + fn atproto_proxy_header(&self) -> impl Future> + Send { + async { None } + } + + /// Get the `atproto-accept-labelers` header. 
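+ /// Defaults to `None`; implementors can override this to forward labeler preferences.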
+ fn atproto_accept_labelers_header(&self) -> impl Future>> + Send { + async { None } + } +} diff --git a/atrium-streams/src/lib.rs b/atrium-streams/src/lib.rs new file mode 100644 index 00000000..8e08c2a6 --- /dev/null +++ b/atrium-streams/src/lib.rs @@ -0,0 +1,4 @@ +pub mod client; +pub mod subscriptions; + +pub use atrium_api; // Re-export the atrium_api crate diff --git a/atrium-streams/src/subscriptions/frames/mod.rs b/atrium-streams/src/subscriptions/frames/mod.rs new file mode 100644 index 00000000..e332721a --- /dev/null +++ b/atrium-streams/src/subscriptions/frames/mod.rs @@ -0,0 +1,81 @@ +//! This file defines the [`FrameHeader`] and [`Frame`] types, which are used to parse the payloads sent by the subscription through the event stream. +//! You can read more about the specs for these types in the [`ATProto documentation`](https://atproto.com/specs/event-stream) + +#[cfg(test)] +mod tests; + +use cbor4ii::core::utils::IoReader; +use ipld_core::ipld::Ipld; +use serde::Deserialize; +use serde_ipld_dagcbor::de::Deserializer; +use std::io::Cursor; + +/// An error type for this crate. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Unknown frame type. Header: {0:?}")] + UnknownFrameType(Ipld), + #[error("Payload was empty. Header: {0:?}")] + EmptyPayload(Ipld), + #[error("Ipld Decoding error: {0}")] + IpldDecoding(#[from] serde_ipld_dagcbor::DecodeError), +} + +/// Represents the header of a frame. It's the first [`Ipld`] object in a Binary payload sent by a subscription. +#[derive(Debug, Clone, PartialEq, Eq)] +enum FrameHeader { + Message { t: String }, + Error, +} + +impl TryFrom for FrameHeader { + type Error = self::Error; + + fn try_from(header: Ipld) -> Result>::Error> { + if let Ipld::Map(ref map) = header { + if let Some(Ipld::Integer(i)) = map.get("op") { + match i { + 1 => { + if let Some(Ipld::String(s)) = map.get("t") { + return Ok(Self::Message { t: s.to_owned() }); + } + } + -1 => return Ok(Self::Error), + _ => {} + } + } + } + Err(Error::UnknownFrameType(header)) + } +} + +/// Represents a frame sent by a subscription. It's the second [`Ipld`] object in a Binary payload sent by a subscription. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Frame { + Message { t: String, data: Vec }, + Error { data: Vec }, +} + +impl TryFrom> for Frame { + type Error = self::Error; + + fn try_from(value: Vec) -> Result>>::Error> { + let mut cursor = Cursor::new(value); + let mut deserializer = Deserializer::from_reader(IoReader::new(&mut cursor)); + let header = Deserialize::deserialize(&mut deserializer)?; + + // Error means the stream did not end (trailing data), which implies a second IPLD (in this case, the payload). + // If the stream ended, the payload is empty, in which case we error. + let data = if deserializer.end().is_err() { + let pos = cursor.position() as usize; + cursor.get_mut().drain(pos..).collect() + } else { + return Err(Error::EmptyPayload(header)); + }; + + match FrameHeader::try_from(header)? 
{ + FrameHeader::Message { t } => Ok(Self::Message { t, data }), + FrameHeader::Error => Ok(Self::Error { data }), + } + } +} diff --git a/atrium-streams/src/subscriptions/frames/tests.rs b/atrium-streams/src/subscriptions/frames/tests.rs new file mode 100644 index 00000000..10804549 --- /dev/null +++ b/atrium-streams/src/subscriptions/frames/tests.rs @@ -0,0 +1,56 @@ +use super::*; + +fn serialized_data(s: &str) -> Vec { + assert!(s.len() % 2 == 0); + let b2u = |b: u8| match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'f' => b - b'a' + 10, + _ => unreachable!(), + }; + s.as_bytes().chunks(2).map(|b| (b2u(b[0]) << 4) + b2u(b[1])).collect() +} + +#[test] +fn deserialize_message_frame_header() { + // {"op": 1, "t": "#commit"} + let data = serialized_data("a2626f700161746723636f6d6d6974"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect("failed to deserialize"), + FrameHeader::Message { t: String::from("#commit") } + ); +} + +#[test] +fn deserialize_error_frame_header() { + // {"op": -1} + let data = serialized_data("a1626f7020"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!(result.expect("failed to deserialize"), FrameHeader::Error); +} + +#[test] +fn deserialize_invalid_frame_header() { + { + // {"op": 2, "t": "#commit"} + let data = serialized_data("a2626f700261746723636f6d6d6974"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect_err("must be failed").to_string(), + "Unknown frame type. Header: {\"op\": 2, \"t\": \"#commit\"}" + ); + } + { + // {"op": -2} + let data = serialized_data("a1626f7021"); + let ipld = serde_ipld_dagcbor::from_slice::(&data).expect("failed to deserialize"); + let result = FrameHeader::try_from(ipld); + assert_eq!( + result.expect_err("must be failed").to_string(), + "Unknown frame type. Header: {\"op\": -2}" + ); + } +} diff --git a/atrium-streams/src/subscriptions/handlers/mod.rs b/atrium-streams/src/subscriptions/handlers/mod.rs new file mode 100644 index 00000000..10de0374 --- /dev/null +++ b/atrium-streams/src/subscriptions/handlers/mod.rs @@ -0,0 +1,3 @@ +use super::{ConnectionHandler, ProcessedPayload}; + +pub mod repositories; diff --git a/atrium-streams/src/subscriptions/handlers/repositories.rs b/atrium-streams/src/subscriptions/handlers/repositories.rs new file mode 100644 index 00000000..9a4b0198 --- /dev/null +++ b/atrium-streams/src/subscriptions/handlers/repositories.rs @@ -0,0 +1,129 @@ +#![allow(unused_variables)] + +use std::future::Future; + +use atrium_api::com::atproto::sync::subscribe_repos; + +use super::{ConnectionHandler, ProcessedPayload}; + +/// This type should be used to define [`ConnectionHandler::HandledData`](ConnectionHandler::HandledData) +/// for the `com.atproto.sync.subscribeRepos` subscription type. +pub type HandledData = ProcessedData< + ::ProcessedCommitData, + ::ProcessedIdentityData, + ::ProcessedAccountData, + ::ProcessedHandleData, + ::ProcessedMigrateData, + ::ProcessedTombstoneData, + ::ProcessedInfoData, +>; + +/// Wrapper around all the possible types of processed data. 
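+/// Each type parameter corresponds to one of the `Processed*Data` associated types on [`Handler`].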
+#[derive(Debug)]
+pub enum ProcessedData<C, I0, A, H, M, T, I1> {
+    Commit(C),
+    Identity(I0),
+    Account(A),
+    Handle(H),
+    Migrate(M),
+    Tombstone(T),
+    Info(I1),
+}
+
+/// A trait that defines a [`ConnectionHandler`] specific to the
+/// `com.atproto.sync.subscribeRepos` subscription type.
+///
+/// Any struct that fully and correctly implements this trait will be able to
+/// handle all the different payload types that the subscription can send.
+/// Since the final desired result data type might change for each case, the
+/// trait is generic, and the implementor must define the data type for each
+/// payload they intend to use. The same goes for the implementations of
+/// each processing method, as the algorithm may vary.
+pub trait Handler: ConnectionHandler {
+    type ProcessedCommitData;
+    /// Processes a payload of type `#commit`.
+    fn process_commit(
+        &self,
+        payload: subscribe_repos::Commit,
+    ) -> impl Future<
+        Output = Result<Option<ProcessedPayload<Self::ProcessedCommitData>>, Self::HandlingError>,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+
+    type ProcessedIdentityData;
+    /// Processes a payload of type `#identity`.
+    fn process_identity(
+        &self,
+        payload: subscribe_repos::Identity,
+    ) -> impl Future<
+        Output = Result<Option<ProcessedPayload<Self::ProcessedIdentityData>>, Self::HandlingError>,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+
+    type ProcessedAccountData;
+    /// Processes a payload of type `#account`.
+    fn process_account(
+        &self,
+        payload: subscribe_repos::Account,
+    ) -> impl Future<
+        Output = Result<Option<ProcessedPayload<Self::ProcessedAccountData>>, Self::HandlingError>,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+
+    type ProcessedHandleData;
+    /// Processes a payload of type `#handle`.
+    fn process_handle(
+        &self,
+        payload: subscribe_repos::Handle,
+    ) -> impl Future<
+        Output = Result<Option<ProcessedPayload<Self::ProcessedHandleData>>, Self::HandlingError>,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+
+    type ProcessedMigrateData;
+    /// Processes a payload of type `#migrate`.
+    fn process_migrate(
+        &self,
+        payload: subscribe_repos::Migrate,
+    ) -> impl Future<
+        Output = Result<Option<ProcessedPayload<Self::ProcessedMigrateData>>, Self::HandlingError>,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+
+    type ProcessedTombstoneData;
+    /// Processes a payload of type `#tombstone`.
+    fn process_tombstone(
+        &self,
+        payload: subscribe_repos::Tombstone,
+    ) -> impl Future<
+        Output = Result<
+            Option<ProcessedPayload<Self::ProcessedTombstoneData>>,
+            Self::HandlingError,
+        >,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+
+    type ProcessedInfoData;
+    /// Processes a payload of type `#info`.
+    fn process_info(
+        &self,
+        payload: subscribe_repos::Info,
+    ) -> impl Future<
+        Output = Result<Option<ProcessedPayload<Self::ProcessedInfoData>>, Self::HandlingError>,
+    > {
+        // Default implementation always returns `None`, meaning the implementation decided to ignore the payload.
+        async { Ok(None) }
+    }
+}
diff --git a/atrium-streams/src/subscriptions/mod.rs b/atrium-streams/src/subscriptions/mod.rs
new file mode 100644
index 00000000..46f113b7
--- /dev/null
+++ b/atrium-streams/src/subscriptions/mod.rs
@@ -0,0 +1,78 @@
+pub mod frames;
+pub mod handlers;
+
+use std::{fmt::Debug, future::Future};
+
+use futures::Stream;
+
+/// A trait that defines the connection handler.
+pub trait ConnectionHandler {
+    /// The [`Self::HandledData`](ConnectionHandler::HandledData) type should be used to define the returned processed data type.
+    type HandledData;
+    /// The [`Self::HandlingError`](ConnectionHandler::HandlingError) type should be used to define the processing error type.
+    type HandlingError: 'static + Send + Sync + Debug;
+
+    /// Handles binary data coming from the connection. This function will deserialize the payload body and call the appropriate
+    /// handler for each payload type.
+    ///
+    /// # Returns
+    /// [`Result<Option<ProcessedPayload>>`] like:
+    /// - `Ok(Some(processed_payload))` where `processed_payload` is a [`ProcessedPayload`](ProcessedPayload)
+    ///   if the payload was successfully processed.
+    /// - `Ok(None)` if the payload was ignored.
+    /// - `Err(e)` where `e` is [`ConnectionHandler::HandlingError`] if an error occurred while processing the payload.
+    fn handle_payload(
+        &self,
+        t: String,
+        payload: Vec<u8>,
+    ) -> impl Future<Output = Result<Option<ProcessedPayload<Self::HandledData>>, Self::HandlingError>>;
+}
+
+/// A trait that defines a subscription.
+/// It should be implemented by any struct that wants to handle a connection.
+/// The `ConnectionPayload` type parameter is the type of the payload that will be received through the connection stream.
+/// The `Error` type parameter is the type of the error that the specific subscription can return, following the lexicon.
+pub trait Subscription<ConnectionPayload, Error> {
+    /// The `handle_connection` method should be implemented to handle the connection.
+    ///
+    /// # Returns
+    /// A stream of processed payloads.
+    fn handle_connection<H: ConnectionHandler>(
+        connection: impl Stream<Item = ConnectionPayload> + Unpin,
+        handler: H,
+    ) -> impl Stream<Item = Result<ProcessedPayload<H::HandledData>, SubscriptionError<Error>>>;
+}
+
+/// This struct represents a processed payload.
+/// It contains the sequence number (cursor) and the final processed data.
+pub struct ProcessedPayload<Kind> {
+    pub seq: Option<i64>, // Might be absent, as in the case of #info.
+    pub data: Kind,
+}
+
+/// Helper for converting between payload kinds.
+impl<Kind> ProcessedPayload<Kind> {
+    pub fn map<NewKind, F: FnOnce(Kind) -> NewKind>(self, f: F) -> ProcessedPayload<NewKind> {
+        ProcessedPayload { seq: self.seq, data: f(self.data) }
+    }
+}
+
+/// An error type that represents a subscription error.
+///
+/// `Abort` is a hard error, and the subscription should cancel.
+/// This follows the [`ATProto Specs`](https://atproto.com/specs/event-stream).
+///
+/// `Unknown` is an error that is not recognized by the subscription.
+/// This can be used to handle unexpected errors.
+///
+/// `Other` is an error specific to the subscription type.
+/// This can be used to handle different kinds of errors, following the lexicon.
+#[derive(Debug, thiserror::Error)]
+pub enum SubscriptionError<T> {
+    #[error("Critical Subscription Error: {0}")]
+    Abort(String),
+    #[error("Unknown Subscription Error: {0}")]
+    Unknown(String),
+    #[error(transparent)]
+    Other(T),
+}
diff --git a/examples/firehose/Cargo.toml b/examples/firehose/Cargo.toml
index 70bab61d..d9fceaca 100644
--- a/examples/firehose/Cargo.toml
+++ b/examples/firehose/Cargo.toml
@@ -6,13 +6,8 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-anyhow = "1.0.80"
-atrium-api = { version = "0.18.1", features = ["dag-cbor"] }
-chrono = "0.4.34"
+anyhow = "1.0.86"
+atrium-streams-client = { path = "../../atrium-streams-client" }
 futures = "0.3.30"
-ipld-core = { version = "0.4.0", default-features = false, features = ["std"] }
-rs-car = "0.4.1"
-serde_ipld_dagcbor = { version = "0.6.0", default-features = false, features = ["std"] }
-tokio = { version = "1.36.0", features = ["full"] }
 tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] }
-trait-variant = "0.1.1"
+tokio = { version = "1.36.0", features = ["full"] }
\ No newline at end of file
diff --git a/examples/firehose/src/lib.rs b/examples/firehose/src/lib.rs
deleted file mode 100644
index b4e04262..00000000
--- a/examples/firehose/src/lib.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-pub mod stream;
-pub mod subscription;
diff --git a/examples/firehose/src/main.rs b/examples/firehose/src/main.rs
index e70a237e..ebd9e8b7 100644
--- a/examples/firehose/src/main.rs
+++ b/examples/firehose/src/main.rs
@@ -1,85 +1,154 @@
-use anyhow::{anyhow, Result};
-use atrium_api::app::bsky::feed::post::Record;
-use atrium_api::com::atproto::sync::subscribe_repos::{Commit, NSID};
-use atrium_api::types::{CidLink, Collection};
-use chrono::Local;
-use firehose::stream::frames::Frame;
-use firehose::subscription::{CommitHandler, Subscription};
+use anyhow::bail;
+use atrium_streams_client::{
+    atrium_streams::{
+        atrium_api::com::atproto::sync::subscribe_repos::{self, InfoData},
+        client::EventStreamClient,
+        subscriptions::{
+            handlers::repositories::ProcessedData, ProcessedPayload, SubscriptionError,
+        },
+    },
+    subscriptions::repositories::{
+        firehose::Firehose,
+        type_defs::{Operation, ProcessedCommitData},
+        Repositories,
+    },
+    WssClient, Error,
+};
 use futures::StreamExt;
-use tokio::net::TcpStream;
-use tokio_tungstenite::tungstenite::Message;
-use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
+use tokio_tungstenite::tungstenite;
 
-struct RepoSubscription {
-    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
+/// This example demonstrates how to connect to the ATProto Firehose.
+#[tokio::main]
+async fn main() {
+    // Define the Uri for the subscription.
+    let uri = format!("wss://bsky.network/xrpc/{}", subscribe_repos::NSID);
+
+    // Caching the last cursor is important.
+    // The API has a backfilling mechanism that allows you to resume from where you stopped.
+    let mut last_cursor = None;
+    drop(connect(&mut last_cursor, uri).await);
 }
 
-impl RepoSubscription {
-    async fn new(bgs: &str) -> Result<RepoSubscription, Box<dyn std::error::Error>> {
-        let (stream, _) = connect_async(format!("wss://{bgs}/xrpc/{NSID}")).await?;
-        Ok(RepoSubscription { stream })
-    }
-    async fn run(&mut self, handler: impl CommitHandler) -> Result<(), Box<dyn std::error::Error>> {
-        while let Some(result) = self.next().await {
-            if let Ok(Frame::Message(Some(t), message)) = result {
-                if t.as_str() == "#commit" {
-                    let commit = serde_ipld_dagcbor::from_reader(message.body.as_slice())?;
-                    if let Err(err) = handler.handle_commit(&commit).await {
-                        eprintln!("FAILED: {err:?}");
-                    }
+/// Connects to `ATProto` to receive real-time data.
+async fn connect(
+    last_cursor: &mut Option<i64>,
+    uri: String,
+) -> Result<(), anyhow::Error> {
+    // Define the query parameters. In this case, just the cursor.
+    let params = subscribe_repos::ParametersData {
+        cursor: *last_cursor,
+    };
+
+    // Build a new XRPC WSS Client.
+    let client = WssClient::builder()
+        .params(params)
+        .build();
+
+    // And then we connect to the API.
+    let connection = match client.connect(uri).await {
+        Ok(connection) => connection,
+        Err(Error::Connection(tungstenite::Error::Http(response))) => {
+            // According to the API documentation, the following status codes are expected and should be treated accordingly:
+            // 405 Method Not Allowed: Returned to client for non-GET HTTP requests to a stream endpoint.
+            // 426 Upgrade Required: Returned to client if Upgrade header is not included in a request to a stream endpoint.
+            // 429 Too Many Requests: Frequently used for rate-limiting. Client may try again after a delay. Support for the Retry-After header is encouraged.
+            // 500 Internal Server Error: Client may try again after a delay.
+            // 501 Not Implemented: Service does not implement WebSockets or streams, at least for this endpoint. Client should not try again.
+            // 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout: Client may try again after a delay.
+            // https://atproto.com/specs/event-stream
+            bail!("Status Code was: {response:?}")
+        }
+        Err(e) => bail!(e),
+    };
+
+    // Builds the subscription handler.
+    let firehose = Firehose::builder()
+        // You can enable or disable specific events; every event is disabled by default.
+        // That way they don't get unnecessarily processed and you save resources.
+        // Enable only the ones you plan to use.
+        .enable_commit(true)
+        .enable_info(true)
+        .build();
+
+    // Builds a new subscription from the connection, using the handler provided
+    // by atrium-streams-client, the `Firehose`.
+    let mut subscription = Repositories::builder()
+        .connection(connection)
+        .handler(firehose)
+        .build();
+
+    // Receive payloads by calling `StreamExt::next()`.
+    while let Some(payload) = subscription.next().await {
+        let data = match payload {
+            Ok(ProcessedPayload { seq, data }) => {
+                if let Some(seq) = seq {
+                    *last_cursor = Some(seq);
                 }
+                data
             }
-        }
-        Ok(())
-    }
-}
+            Err(SubscriptionError::Abort(reason)) => {
+                // This could mean multiple things, all of which are critical errors that require
+                // immediate termination of the connection.
+                eprintln!("Aborted: {reason}");
+                *last_cursor = None;
+                break;
+            }
+            Err(e) => {
+                // Errors such as `FutureCursor` and `ConsumerTooSlow` can be dealt with here.
+ eprintln!("{e:?}"); + *last_cursor = None; + break; + } + }; -impl Subscription for RepoSubscription { - async fn next(&mut self) -> Option>::Error>> { - if let Some(Ok(Message::Binary(data))) = self.stream.next().await { - Some(Frame::try_from(data.as_slice())) - } else { - None - } + match data { + ProcessedData::Commit(data) => beauty_print_commit(data), + ProcessedData::Info(InfoData { message, name }) => { + println!("Received info. Message: {message:?}; Name: {name}."); + } + _ => { /* Ignored */ } + }; } -} -struct Firehose; + Ok(()) +} -impl CommitHandler for Firehose { - async fn handle_commit(&self, commit: &Commit) -> Result<()> { - for op in &commit.ops { - let collection = op.path.split('/').next().expect("op.path is empty"); - if op.action != "create" || collection != atrium_api::app::bsky::feed::Post::NSID { - continue; - } - let (items, _) = rs_car::car_read_all(&mut commit.blocks.as_slice(), true).await?; - if let Some((_, item)) = items.iter().find(|(cid, _)| Some(CidLink(*cid)) == op.cid) { - let record = serde_ipld_dagcbor::from_reader::(&mut item.as_slice())?; +fn beauty_print_commit(data: ProcessedCommitData) { + let ProcessedCommitData { + repo, commit, ops, .. + } = data; + if let Some(ops) = ops { + for r in ops { + let Operation { + action, + path, + record, + } = r; + let print = format!( + "\n\n\n################################# {} ##################################\n\ + - Repository (User DID): {}\n\ + - Commit CID: {}\n\ + - Path: {path}\n\ + - Flagged as \"too big\"? ", + action.to_uppercase(), + repo.as_str(), + commit.0, + ); + // Record is only `None` when the commit was flagged as "too big". + if let Some(record) = record { println!( - "{} - {}", - record.created_at.as_ref().with_timezone(&Local), - commit.repo.as_str() + "{}No\n\ + //-------------------------------- Record Info -------------------------------//\n\n\ + {:?}", + print, record ); - for line in record.text.split('\n') { - println!(" {line}"); - } } else { - return Err(anyhow!( - "FAILED: could not find item with operation cid {:?} out of {} items", - op.cid, - items.len() - )); + println!( + "{}Yes\n\ + //---------------------------------------------------------------------------//\n\n", + print + ); } } - Ok(()) } } - -#[tokio::main] -async fn main() -> Result<(), Box> { - RepoSubscription::new("bsky.network") - .await? 
-        .run(Firehose)
-        .await
-}
diff --git a/examples/firehose/src/stream.rs b/examples/firehose/src/stream.rs
deleted file mode 100644
index b63fcef5..00000000
--- a/examples/firehose/src/stream.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod frames;
diff --git a/examples/firehose/src/stream/frames.rs b/examples/firehose/src/stream/frames.rs
deleted file mode 100644
index 3edd1c1e..00000000
--- a/examples/firehose/src/stream/frames.rs
+++ /dev/null
@@ -1,158 +0,0 @@
-use ipld_core::ipld::Ipld;
-use std::io::Cursor;
-
-// original definition:
-//```
-// export enum FrameType {
-//   Message = 1,
-//   Error = -1,
-// }
-// export const messageFrameHeader = z.object({
-//   op: z.literal(FrameType.Message), // Frame op
-//   t: z.string().optional(), // Message body type discriminator
-// })
-// export type MessageFrameHeader = z.infer<typeof messageFrameHeader>
-// export const errorFrameHeader = z.object({
-//   op: z.literal(FrameType.Error),
-// })
-// export type ErrorFrameHeader = z.infer<typeof errorFrameHeader>
-// ```
-#[derive(Debug, Clone, PartialEq, Eq)]
-enum FrameHeader {
-    Message(Option<String>),
-    Error,
-}
-
-impl TryFrom<Ipld> for FrameHeader {
-    type Error = anyhow::Error;
-
-    fn try_from(value: Ipld) -> Result<Self, <Self as TryFrom<Ipld>>::Error> {
-        if let Ipld::Map(map) = value {
-            if let Some(Ipld::Integer(i)) = map.get("op") {
-                match i {
-                    1 => {
-                        let t = if let Some(Ipld::String(s)) = map.get("t") {
-                            Some(s.clone())
-                        } else {
-                            None
-                        };
-                        return Ok(FrameHeader::Message(t));
-                    }
-                    -1 => return Ok(FrameHeader::Error),
-                    _ => {}
-                }
-            }
-        }
-        Err(anyhow::anyhow!("invalid frame type"))
-    }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum Frame {
-    Message(Option<String>, MessageFrame),
-    Error(ErrorFrame),
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct MessageFrame {
-    pub body: Vec<u8>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct ErrorFrame {
-    // TODO
-    // body: Value,
-}
-
-impl TryFrom<&[u8]> for Frame {
-    type Error = anyhow::Error;
-
-    fn try_from(value: &[u8]) -> Result<Self, <Self as TryFrom<&[u8]>>::Error> {
-        let mut cursor = Cursor::new(value);
-        let (left, right) = match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
-            Err(serde_ipld_dagcbor::DecodeError::TrailingData) => {
-                value.split_at(cursor.position() as usize)
-            }
-            _ => {
-                // TODO
-                return Err(anyhow::anyhow!("invalid frame type"));
-            }
-        };
-        let header = FrameHeader::try_from(serde_ipld_dagcbor::from_slice::<Ipld>(left)?)?;
-        if let FrameHeader::Message(t) = &header {
-            Ok(Frame::Message(
-                t.clone(),
-                MessageFrame {
-                    body: right.to_vec(),
-                },
-            ))
-        } else {
-            Ok(Frame::Error(ErrorFrame {}))
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    fn serialized_data(s: &str) -> Vec<u8> {
-        assert!(s.len() % 2 == 0);
-        let b2u = |b: u8| match b {
-            b'0'..=b'9' => b - b'0',
-            b'a'..=b'f' => b - b'a' + 10,
-            _ => unreachable!(),
-        };
-        s.as_bytes()
-            .chunks(2)
-            .map(|b| (b2u(b[0]) << 4) + b2u(b[1]))
-            .collect()
-    }
-
-    #[test]
-    fn deserialize_message_frame_header() {
-        // {"op": 1, "t": "#commit"}
-        let data = serialized_data("a2626f700161746723636f6d6d6974");
-        let ipld = serde_ipld_dagcbor::from_slice::<Ipld>(&data).expect("failed to deserialize");
-        let result = FrameHeader::try_from(ipld);
-        assert_eq!(
-            result.expect("failed to deserialize"),
-            FrameHeader::Message(Some(String::from("#commit")))
-        );
-    }
-
-    #[test]
-    fn deserialize_error_frame_header() {
-        // {"op": -1}
-        let data = serialized_data("a1626f7020");
-        let ipld = serde_ipld_dagcbor::from_slice::<Ipld>(&data).expect("failed to deserialize");
-        let result = FrameHeader::try_from(ipld);
-        assert_eq!(result.expect("failed to deserialize"), FrameHeader::Error);
-    }
-
-    #[test]
-    fn deserialize_invalid_frame_header() {
-        {
-            // {"op": 2, "t": "#commit"}
-            let data = serialized_data("a2626f700261746723636f6d6d6974");
-            let ipld =
-                serde_ipld_dagcbor::from_slice::<Ipld>(&data).expect("failed to deserialize");
-            let result = FrameHeader::try_from(ipld);
-            assert_eq!(
-                result.expect_err("must be failed").to_string(),
-                "invalid frame type"
-            );
-        }
-        {
-            // {"op": -2}
-            let data = serialized_data("a1626f7021");
-            let ipld =
-                serde_ipld_dagcbor::from_slice::<Ipld>(&data).expect("failed to deserialize");
-            let result = FrameHeader::try_from(ipld);
-            assert_eq!(
-                result.expect_err("must be failed").to_string(),
-                "invalid frame type"
-            );
-        }
-    }
-}
diff --git a/examples/firehose/src/subscription.rs b/examples/firehose/src/subscription.rs
deleted file mode 100644
index 90393105..00000000
--- a/examples/firehose/src/subscription.rs
+++ /dev/null
@@ -1,13 +0,0 @@
-use crate::stream::frames::Frame;
-use anyhow::Result;
-use atrium_api::com::atproto::sync::subscribe_repos::Commit;
-use std::future::Future;
-
-#[trait_variant::make(HttpService: Send)]
-pub trait Subscription {
-    async fn next(&mut self) -> Option<Result<Frame, <Frame as TryFrom<&[u8]>>::Error>>;
-}
-
-pub trait CommitHandler {
-    fn handle_commit(&self, commit: &Commit) -> impl Future<Output = Result<()>>;
-}
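
For context, the smallest possible consumer of the new `ConnectionHandler` trait added in atrium-streams/src/subscriptions/mod.rs could look roughly like the sketch below. This is not part of the diff: `LoggingHandler` and its output type are illustrative assumptions, and a real `com.atproto.sync.subscribeRepos` consumer would instead implement the `Handler` trait from handlers/repositories.rs (as the `Firehose` handler in atrium-streams-client does) and deserialize each payload into its lexicon type.

```rust
use std::{convert::Infallible, future::Future};

use atrium_streams::subscriptions::{ConnectionHandler, ProcessedPayload};

/// Hypothetical handler that only records the frame type tag and the body size.
struct LoggingHandler;

impl ConnectionHandler for LoggingHandler {
    // The "processed data" is just the `t` tag and the payload length.
    type HandledData = (String, usize);
    // Nothing in this sketch can fail.
    type HandlingError = Infallible;

    fn handle_payload(
        &self,
        t: String,
        payload: Vec<u8>,
    ) -> impl Future<Output = Result<Option<ProcessedPayload<Self::HandledData>>, Self::HandlingError>>
    {
        async move {
            // No sequence number is extracted here, so `seq` stays `None`;
            // returning `Ok(None)` instead would mean the payload was ignored.
            Ok(Some(ProcessedPayload { seq: None, data: (t, payload.len()) }))
        }
    }
}
```

Driven by a `Subscription::handle_connection` implementation, a handler like this would yield a stream of `ProcessedPayload<(String, usize)>` values.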