diff --git a/.all-contributorsrc b/.all-contributorsrc index 1db94a0b..95b09876 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -175,11 +175,416 @@ "contributions": [ "code" ] + }, + { + "login": "Nurrl", + "name": "Maya the bee", + "avatar_url": "https://avatars.githubusercontent.com/u/15341887?v=4", + "profile": "https://github.com/Nurrl", + "contributions": [ + "code" + ] + }, + { + "login": "mmirate", + "name": "Milo Mirate", + "avatar_url": "https://avatars.githubusercontent.com/u/992859?v=4", + "profile": "https://github.com/mmirate", + "contributions": [ + "code" + ] + }, + { + "login": "george-hopkins", + "name": "George Hopkins", + "avatar_url": "https://avatars.githubusercontent.com/u/552590?v=4", + "profile": "https://github.com/george-hopkins", + "contributions": [ + "code" + ] + }, + { + "login": "akeamc", + "name": "Åke Amcoff", + "avatar_url": "https://avatars.githubusercontent.com/u/17624114?v=4", + "profile": "https://amcoff.net/", + "contributions": [ + "code" + ] + }, + { + "login": "bho01", + "name": "Brendon Ho", + "avatar_url": "https://avatars.githubusercontent.com/u/12106620?v=4", + "profile": "http://brendonho.com", + "contributions": [ + "code" + ] + }, + { + "login": "samuela", + "name": "Samuel Ainsworth", + "avatar_url": "https://avatars.githubusercontent.com/u/226872?v=4", + "profile": "http://samlikes.pizza/", + "contributions": [ + "code" + ] + }, + { + "login": "sherlock-holo", + "name": "Sherlock Holo", + "avatar_url": "https://avatars.githubusercontent.com/u/10096425?v=4", + "profile": "https://github.com/Sherlock-Holo", + "contributions": [ + "code" + ] + }, + { + "login": "ricott1", + "name": "Alessandro Ricottone", + "avatar_url": "https://avatars.githubusercontent.com/u/16502243?v=4", + "profile": "https://github.com/ricott1", + "contributions": [ + "code" + ] + }, + { + "login": "T0b1-iOS", + "name": "T0b1-iOS", + "avatar_url": "https://avatars.githubusercontent.com/u/15174814?v=4", + "profile": "https://github.com/T0b1-iOS", + "contributions": [ + "code" + ] + }, + { + "login": "shoaibmerchant", + "name": "Shoaib Merchant", + "avatar_url": "https://avatars.githubusercontent.com/u/4598631?v=4", + "profile": "https://mecha.so", + "contributions": [ + "code" + ] + }, + { + "login": "gleason-m", + "name": "Michael Gleason", + "avatar_url": "https://avatars.githubusercontent.com/u/86493344?v=4", + "profile": "https://github.com/gleason-m", + "contributions": [ + "code" + ] + }, + { + "login": "elegaanz", + "name": "Ana Gelez", + "avatar_url": "https://avatars.githubusercontent.com/u/16254623?v=4", + "profile": "https://ana.gelez.xyz", + "contributions": [ + "code" + ] + }, + { + "login": "tomknig", + "name": "Tom König", + "avatar_url": "https://avatars.githubusercontent.com/u/3586316?v=4", + "profile": "https://github.com/tomknig", + "contributions": [ + "code" + ] + }, + { + "login": "Barre", + "name": "Pierre Barre", + "avatar_url": "https://avatars.githubusercontent.com/u/45085843?v=4", + "profile": "https://www.legaltile.com/", + "contributions": [ + "code" + ] + }, + { + "login": "spoutn1k", + "name": "Jean-Baptiste Skutnik", + "avatar_url": "https://avatars.githubusercontent.com/u/22240065?v=4", + "profile": "http://skutnik.page", + "contributions": [ + "code" + ] + }, + { + "login": "packetsource", + "name": "Adam Chappell", + "avatar_url": "https://avatars.githubusercontent.com/u/6276475?v=4", + "profile": "http://blog.packetsource.net/", + "contributions": [ + "code" + ] + }, + { + "login": "CertainLach", + "name": "Yaroslav Bolyukin", + "avatar_url": "https://avatars.githubusercontent.com/u/6235312?v=4", + "profile": "https://github.com/CertainLach", + "contributions": [ + "code" + ] + }, + { + "login": "JuliDi", + "name": "Julian", + "avatar_url": "https://avatars.githubusercontent.com/u/20155974?v=4", + "profile": "http://www.systemscape.de", + "contributions": [ + "code" + ] + }, + { + "login": "grampelberg", + "name": "Thomas Rampelberg", + "avatar_url": "https://avatars.githubusercontent.com/u/47992?v=4", + "profile": "http://saunter.org", + "contributions": [ + "code" + ] + }, + { + "login": "belak", + "name": "Kaleb Elwert", + "avatar_url": "https://avatars.githubusercontent.com/u/107097?v=4", + "profile": "https://belak.io", + "contributions": [ + "doc" + ] + }, + { + "login": "nbdd0121", + "name": "Gary Guo", + "avatar_url": "https://avatars.githubusercontent.com/u/4065244?v=4", + "profile": "https://garyguo.net", + "contributions": [ + "code" + ] + }, + { + "login": "irvingoujAtDevolution", + "name": "irvingouj @ Devolutions", + "avatar_url": "https://avatars.githubusercontent.com/u/139169536?v=4", + "profile": "https://github.com/irvingoujAtDevolution", + "contributions": [ + "code" + ] + }, + { + "login": "Tehforsch", + "name": "Toni Peter", + "avatar_url": "https://avatars.githubusercontent.com/u/4614215?v=4", + "profile": "http://tonipeter.de", + "contributions": [ + "code" + ] + }, + { + "login": "Nathy-bajo", + "name": "Nathaniel Bajo", + "avatar_url": "https://avatars.githubusercontent.com/u/73991674?v=4", + "profile": "https://github.com/Nathy-bajo", + "contributions": [ + "code" + ] + }, + { + "login": "EpicEric", + "name": "Eric Rodrigues Pires", + "avatar_url": "https://avatars.githubusercontent.com/u/3129194?v=4", + "profile": "https://eric.dev.br", + "contributions": [ + "code" + ] + }, + { + "login": "jeromegn", + "name": "Jerome Gravel-Niquet", + "avatar_url": "https://avatars.githubusercontent.com/u/43325?v=4", + "profile": "http://www.fly.io", + "contributions": [ + "code" + ] + }, + { + "login": "qsantos", + "name": "Quentin Santos", + "avatar_url": "https://avatars.githubusercontent.com/u/8493765?v=4", + "profile": "https://qsantos.fr/", + "contributions": [ + "doc" + ] + }, + { + "login": "ogedei-khan", + "name": "André Almeida", + "avatar_url": "https://avatars.githubusercontent.com/u/181673956?v=4", + "profile": "https://github.com/ogedei-khan", + "contributions": [ + "code" + ] + }, + { + "login": "snaggen", + "name": "Mattias Eriksson", + "avatar_url": "https://avatars.githubusercontent.com/u/6420639?v=4", + "profile": "https://github.com/snaggen", + "contributions": [ + "code" + ] + }, + { + "login": "joshka", + "name": "Josh McKinney", + "avatar_url": "https://avatars.githubusercontent.com/u/381361?v=4", + "profile": "http://joshka.net", + "contributions": [ + "code" + ] + }, + { + "login": "citorva", + "name": "citorva", + "avatar_url": "https://avatars.githubusercontent.com/u/16229435?v=4", + "profile": "https://citorva.fr/", + "contributions": [ + "code" + ] + }, + { + "login": "eric-seppanen", + "name": "Eric Seppanen", + "avatar_url": "https://avatars.githubusercontent.com/u/109770420?v=4", + "profile": "https://github.com/eric-seppanen", + "contributions": [ + "code" + ] + }, + { + "login": "ericseppanen", + "name": "Eric Seppanen", + "avatar_url": "https://avatars.githubusercontent.com/u/36317762?v=4", + "profile": "https://codeandbitters.com/", + "contributions": [ + "code" + ] + }, + { + "login": "Patryk27", + "name": "Patryk Wychowaniec", + "avatar_url": "https://avatars.githubusercontent.com/u/3395477?v=4", + "profile": "https://pwy.io", + "contributions": [ + "code" + ] + }, + { + "login": "RandyMcMillan", + "name": "@RandyMcMillan", + "avatar_url": "https://avatars.githubusercontent.com/u/152159?v=4", + "profile": "https://www.randymcmillan.net", + "contributions": [ + "code" + ] + }, + { + "login": "handewo", + "name": "handewo", + "avatar_url": "https://avatars.githubusercontent.com/u/20971373?v=4", + "profile": "https://github.com/handewo", + "contributions": [ + "code" + ] + }, + { + "login": "ccbrown", + "name": "Chris", + "avatar_url": "https://avatars.githubusercontent.com/u/1731074?v=4", + "profile": "https://github.com/ccbrown", + "contributions": [ + "code" + ] + }, + { + "login": "procr1337", + "name": "procr1337", + "avatar_url": "https://avatars.githubusercontent.com/u/193802945?v=4", + "profile": "https://github.com/procr1337", + "contributions": [ + "code" + ] + }, + { + "login": "Itsusinn", + "name": "iHsin", + "avatar_url": "https://avatars.githubusercontent.com/u/30529002?v=4", + "profile": "https://github.com/Itsusinn", + "contributions": [ + "code" + ] + }, + { + "login": "psychon", + "name": "Uli Schlachter", + "avatar_url": "https://avatars.githubusercontent.com/u/89482?v=4", + "profile": "https://github.com/psychon", + "contributions": [ + "code" + ] + }, + { + "login": "jvanbrunt", + "name": "Jacob Van Brunt", + "avatar_url": "https://avatars.githubusercontent.com/u/3064793?v=4", + "profile": "https://github.com/jvanbrunt", + "contributions": [ + "code" + ] + }, + { + "login": "lgmugnier", + "name": "lgmugnier", + "avatar_url": "https://avatars.githubusercontent.com/u/10800317?v=4", + "profile": "https://github.com/lgmugnier", + "contributions": [ + "code" + ] + }, + { + "login": "MingweiSamuel", + "name": "Mingwei Samuel", + "avatar_url": "https://avatars.githubusercontent.com/u/6778341?v=4", + "profile": "https://github.com/MingweiSamuel", + "contributions": [ + "code" + ] + }, + { + "login": "pgrange", + "name": "Pascal Grange", + "avatar_url": "https://avatars.githubusercontent.com/u/378506?v=4", + "profile": "https://twitter.com/pascalgrange", + "contributions": [ + "code" + ] + }, + { + "login": "wyhaya", + "name": "wyhaya", + "avatar_url": "https://avatars.githubusercontent.com/u/23690145?v=4", + "profile": "https://github.com/wyhaya", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, "projectName": "russh", - "projectOwner": "warp-tech", + "projectOwner": "Eugeny", "repoType": "github", "repoHost": "https://github.com", "skipCi": true, diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..c236b04c --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +github: eugeny +open_collective: tabby +ko_fi: eugeny diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 66ef82a8..a0164e99 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,16 +2,16 @@ name: Rust on: push: - branches: [ master ] + branches: [ main ] pull_request: - branches: [ master ] + branches: [ main ] env: CARGO_TERM_COLOR: always jobs: Build: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v2 @@ -22,13 +22,51 @@ jobs: - name: Build (all features enabled) run: cargo build --verbose --all-features - - name: Check semver compatibility (russh) - uses: obi1kenobi/cargo-semver-checks-action@v2 - with: - package: russh + Build-Windows: + runs-on: windows-latest + + steps: + - uses: actions/checkout@v2 + + - name: install nasm + run: | + choco install nasm + echo "C:\Program Files\NASM" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + + - name: Build (no features enabled) + run: cargo build --verbose + + - name: Build (all features enabled) + run: cargo build --verbose --all-features + + Build-WASM: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v2 + + - name: Install target + run: | + rustup toolchain add 1.81.0 + rustup target add --toolchain 1.81.0 wasm32-wasip1 + + - name: Build (WASM-compatible features) + run: cargo +1.81.0 build --verbose --target wasm32-wasip1 --no-default-features --features flate2,ring -p russh + + Formatting: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v2 + + - name: Install rustfmt + run: rustup component add rustfmt + + - name: rustfmt + run: cargo fmt --check Clippy: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v2 @@ -43,7 +81,7 @@ jobs: run: cargo clippy --all-features -- -D warnings Test: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v2 @@ -61,3 +99,16 @@ jobs: cargo test --verbose --all-features env: RUST_BACKTRACE: 1 + + Minimal-versions: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v2 + - uses: taiki-e/install-action@cargo-hack + - uses: taiki-e/install-action@cargo-minimal-versions + + - name: Check with minimal dependency versions + run: | + rustup toolchain add 1.75.0 + cargo +1.75.0 minimal-versions check --all-features --no-dev-deps diff --git a/.github/workflows/semver.yml b/.github/workflows/semver.yml new file mode 100644 index 00000000..90a8ca4f --- /dev/null +++ b/.github/workflows/semver.yml @@ -0,0 +1,26 @@ +name: Semver check + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +env: + CARGO_TERM_COLOR: always + +jobs: + Build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Build (no features enabled) + run: cargo build --verbose + + - name: Check semver compatibility (russh) + uses: obi1kenobi/cargo-semver-checks-action@v2 + continue-on-error: true + with: + package: russh diff --git a/.gitignore b/.gitignore index abb7748a..c2d07934 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target Cargo.lock .cargo-ok +ca-test* diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..8031f7f6 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "rust-analyzer.check.command": "check" +} diff --git a/.well-known/funding-manifest-urls b/.well-known/funding-manifest-urls new file mode 100644 index 00000000..c510488c --- /dev/null +++ b/.well-known/funding-manifest-urls @@ -0,0 +1 @@ +https://null.page/funding.json diff --git a/Cargo.toml b/Cargo.toml index d4ed211c..766b149d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,36 @@ [workspace] -members = [ "russh-keys", "russh", "russh-config", "cryptovec"] +members = ["russh", "russh-config", "cryptovec", "pageant", "russh-util"] +resolver = "2" -[patch.crates-io] -russh = { path = "russh" } -russh-keys = { path = "russh-keys" } -russh-cryptovec = { path = "cryptovec" } -russh-config = { path = "russh-config" } +[workspace.dependencies] +aes = "0.8" +async-trait = "0.1.50" +byteorder = "1.4" +bytes = "1.7" +digest = "0.10" +delegate = "0.13" +env_logger = "0.6" +futures = "0.3" +home = "0.5" +hmac = "0.12" +log = "0.4.11" +rand = "0.8" +rsa = "0.9" +sha1 = { version = "0.10.5", features = ["oid"] } +sha2 = { version = "0.10.6", features = ["oid"] } +signature = "2.2" +ssh-encoding = { version = "0.2", features = ["bytes"] } +ssh-key = { version = "=0.6.11", features = [ + "ed25519", + "rsa", + "rsa-sha1", + "p256", + "p384", + "p521", + "encryption", + "ppk", + "hazmat-allow-insecure-rsa-keys", +], package = "internal-russh-forked-ssh-key" } +thiserror = "1.0.30" +tokio = { version = "1.17.0" } +tokio-stream = { version = "0.1.3", features = ["net", "sync"] } diff --git a/README.md b/README.md index 76346d74..22184062 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,45 @@ # Russh + [![Rust](https://github.com/warp-tech/russh/actions/workflows/rust.yml/badge.svg)](https://github.com/warp-tech/russh/actions/workflows/rust.yml) -[![All Contributors](https://img.shields.io/badge/all_contributors-19-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-64-orange.svg?style=flat-square)](#contributors-) Low-level Tokio SSH2 client and server implementation. +Examples: [simple client](russh/examples/client_exec_simple.rs), [interactive PTY client](russh/examples/client_exec_interactive.rs), [server](russh/examples/echoserver.rs), [SFTP client](russh/examples/sftp_client.rs), [SFTP server](russh/examples/sftp_server.rs). + This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Étienne Meunier. > ✨ = added in Russh * [More panic safety](https://github.com/warp-tech/russh#safety) ✨ -* `async_trait` support ✨ +* async traits ✨ * `direct-tcpip` (local port forwarding) * `forward-tcpip` (remote port forwarding) ✨ * `direct-streamlocal` (local UNIX socket forwarding, client only) ✨ +* `forward-streamlocal` (remote UNIX socket forwarding) ✨ * Ciphers: * `chacha20-poly1305@openssh.com` + * `aes128-gcm@openssh.com` ✨ * `aes256-gcm@openssh.com` ✨ * `aes256-ctr` ✨ * `aes192-ctr` ✨ * `aes128-ctr` ✨ + * `aes256-cbc` ✨ + * `aes192-cbc` ✨ + * `aes128-cbc` ✨ + * `3des-cbc` ✨ * Key exchanges: * `curve25519-sha256@libssh.org` + * `diffie-hellman-group-sha1` (GEX) ✨ * `diffie-hellman-group1-sha1` ✨ * `diffie-hellman-group14-sha1` ✨ + * `diffie-hellman-group-sha256` (GEX) ✨ * `diffie-hellman-group14-sha256` ✨ + * `diffie-hellman-group16-sha512` ✨ + * `ecdh-sha2-nistp256` ✨ + * `ecdh-sha2-nistp384` ✨ + * `ecdh-sha2-nistp521` ✨ * MACs: * `hmac-sha1` ✨ * `hmac-sha2-256` ✨ @@ -32,15 +47,27 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti * `hmac-sha1-etm@openssh.com` ✨ * `hmac-sha2-256-etm@openssh.com` ✨ * `hmac-sha2-512-etm@openssh.com` ✨ -* Host keys: +* Host keys and public key auth: * `ssh-ed25519` * `rsa-sha2-256` * `rsa-sha2-512` * `ssh-rsa` ✨ + * `ecdsa-sha2-nistp256` ✨ + * `ecdsa-sha2-nistp384` ✨ + * `ecdsa-sha2-nistp521` ✨ +* Authentication methods: + * `password` + * `publickey` + * `keyboard-interactive` + * `none` + * OpenSSH certificates ✨ * Dependency updates * OpenSSH keepalive request handling ✨ * OpenSSH agent forwarding channels ✨ * OpenSSH `server-sig-algs` extension ✨ +* PPK key format ✨ +* Pageant support ✨ +* `AsyncRead`/`AsyncWrite`-able channels ✨ ## Safety @@ -53,6 +80,7 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti ### Panics * When the Rust allocator fails to allocate memory during a CryptoVec being resized. +* When `mlock`/`munlock` fails to protect sensitive data in memory. ### Unsafe code @@ -60,9 +88,32 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti ## Ecosystem -* [russh-sftp](https://crates.io/crates/russh-sftp) - server-side SFTP subsystem support for `russh` - see `russh/examples/sftp_server.rs`. +* [russh-sftp](https://crates.io/crates/russh-sftp) - server-side and client-side SFTP subsystem support for `russh` - see `russh/examples/sftp_server.rs` or `russh/examples/sftp_client.rs`. * [async-ssh2-tokio](https://crates.io/crates/async-ssh2-tokio) - simple high-level API for running commands over SSH. +## Adopters + +* [HexPatch](https://github.com/Etto48/HexPatch) - A binary patcher and editor written in Rust with terminal user interface (TUI). + * Uses `russh::client` and `russh_sftp::client` to allow remote editing of files. +* [kartoffels](https://github.com/Patryk27/kartoffels) - A game where you're given a potato and your job is to implement a firmware for it + * Uses `russh:server` to deliver the game, using `ratatui` as the rendering engine. +* [kty](https://github.com/grampelberg/kty) - The terminal for Kubernetes. + * Uses `russh::server` to deliver the `ratatui` based TUI and `russh_sftp::server` to provide `scp` based file management. +* [lapdev](https://github.com/lapce/lapdev) - Self-Hosted Remote Dev Environment + * Uses `russh::server` to construct a proxy into your development environment. +* [medusa](https://github.com/evilsocket/medusa) - A fast and secure multi protocol honeypot. + * Uses `russh::server` to be the basis of the honeypot. +* [rebels-in-the-sky](https://github.com/ricott1/rebels-in-the-sky) - P2P terminal game about spacepirates playing basketball across the galaxy + * Uses `russh::server` to deliver the game, using `ratatui` as the rendering engine. +* [warpgate](https://github.com/warp-tech/warpgate) - Smart SSH, HTTPS and MySQL bastion that requires no additional client-side software + * Uses `russh::server` in addition to `russh::client` as part of the smart SSH functionality. +* [Devolutions Gateway](https://github.com/Devolutions/devolutions-gateway/) - Establish a secure entry point for internal or external segmented networks that require authorized just-in-time (JIT) access. + * Uses `russh::client` for the web-based SSH client of the standalone web application. +* [Sandhole](https://github.com/EpicEric/sandhole) - Expose HTTP/SSH/TCP services through SSH port forwarding. A reverse proxy that just works with an OpenSSH client. + * Uses `russh::server` for reverse forwarding connections, local forwarding tunnels, and the `ratatui` based admin interface. +* [Motor OS](https://github.com/moturus/motor-os) - A new Rust-based operating system for VMs. + * Uses `russh::server` as the base for its own [SSH Server](https://github.com/moturus/motor-os/tree/main/src/bin/russhd). + ## Contributors ✨ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): @@ -73,29 +124,88 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d - - - - - - - + + + + + + + - - - - - - - + + + + + + + - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Mihir Samdarshi
Mihir Samdarshi

📖
Connor Peet
Connor Peet

💻
KVZN
KVZN

💻
Adrian Müller (DTT)
Adrian Müller (DTT)

💻
Simone Margaritelli
Simone Margaritelli

💻
Joe Grund
Joe Grund

💻
AspectUnk
AspectUnk

💻
Mihir Samdarshi
Mihir Samdarshi

📖
Connor Peet
Connor Peet

💻
KVZN
KVZN

💻
Adrian Müller (DTT)
Adrian Müller (DTT)

💻
Simone Margaritelli
Simone Margaritelli

💻
Joe Grund
Joe Grund

💻
AspectUnk
AspectUnk

💻
Simão Mata
Simão Mata

💻
Mariotaku
Mariotaku

💻
yorkz1994
yorkz1994

💻
Ciprian Dorin Craciun
Ciprian Dorin Craciun

💻
Eric Milliken
Eric Milliken

💻
Swelio
Swelio

💻
Joshua Benz
Joshua Benz

💻
Simão Mata
Simão Mata

💻
Mariotaku
Mariotaku

💻
yorkz1994
yorkz1994

💻
Ciprian Dorin Craciun
Ciprian Dorin Craciun

💻
Eric Milliken
Eric Milliken

💻
Swelio
Swelio

💻
Joshua Benz
Joshua Benz

💻
Jan Holthuis
Jan Holthuis

🛡️
mateuszkj
mateuszkj

💻
Saksham Mittal
Saksham Mittal

💻
Lucas Kent
Lucas Kent

💻
Raphael Druon
Raphael Druon

💻
mateuszkj
mateuszkj

💻
Saksham Mittal
Saksham Mittal

💻
Lucas Kent
Lucas Kent

💻
Raphael Druon
Raphael Druon

💻
Maya the bee
Maya the bee

💻
Milo Mirate
Milo Mirate

💻
George Hopkins
George Hopkins

💻
Åke Amcoff
Åke Amcoff

💻
Brendon Ho
Brendon Ho

💻
Samuel Ainsworth
Samuel Ainsworth

💻
Sherlock Holo
Sherlock Holo

💻
Alessandro Ricottone
Alessandro Ricottone

💻
T0b1-iOS
T0b1-iOS

💻
Shoaib Merchant
Shoaib Merchant

💻
Michael Gleason
Michael Gleason

💻
Ana Gelez
Ana Gelez

💻
Tom König
Tom König

💻
Pierre Barre
Pierre Barre

💻
Jean-Baptiste Skutnik
Jean-Baptiste Skutnik

💻
Adam Chappell
Adam Chappell

💻
Yaroslav Bolyukin
Yaroslav Bolyukin

💻
Julian
Julian

💻
Thomas Rampelberg
Thomas Rampelberg

💻
Kaleb Elwert
Kaleb Elwert

📖
Gary Guo
Gary Guo

💻
irvingouj @ Devolutions
irvingouj @ Devolutions

💻
Toni Peter
Toni Peter

💻
Nathaniel Bajo
Nathaniel Bajo

💻
Eric Rodrigues Pires
Eric Rodrigues Pires

💻
Jerome Gravel-Niquet
Jerome Gravel-Niquet

💻
Quentin Santos
Quentin Santos

📖
André Almeida
André Almeida

💻
Mattias Eriksson
Mattias Eriksson

💻
Josh McKinney
Josh McKinney

💻
citorva
citorva

💻
Eric Seppanen
Eric Seppanen

💻
Eric Seppanen
Eric Seppanen

💻
Patryk Wychowaniec
Patryk Wychowaniec

💻
@RandyMcMillan
@RandyMcMillan

💻
handewo
handewo

💻
Chris
Chris

💻
procr1337
procr1337

💻
iHsin
iHsin

💻
Uli Schlachter
Uli Schlachter

💻
Jacob Van Brunt
Jacob Van Brunt

💻
lgmugnier
lgmugnier

💻
Mingwei Samuel
Mingwei Samuel

💻
Pascal Grange
Pascal Grange

💻
wyhaya
wyhaya

💻
diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..f2d775ae --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,7 @@ +# Security Policy + +## Reporting a Vulnerability + +Please report vunerabilities using GitHub's Private Vulnerability Reporting tool. + +You can expect a response within a few days. diff --git a/bench.sh b/bench.sh new file mode 100755 index 00000000..713b7e0f --- /dev/null +++ b/bench.sh @@ -0,0 +1,2 @@ +#!/bin/sh +RUSTFLAGS="-Ctarget-cpu=native" cargo bench -F _bench diff --git a/cryptovec/Cargo.toml b/cryptovec/Cargo.toml index 92450230..2f861c83 100644 --- a/cryptovec/Cargo.toml +++ b/cryptovec/Cargo.toml @@ -2,13 +2,33 @@ authors = ["Pierre-Étienne Meunier ", "Eugeny ) -> std::fmt::Result { + if self.size == 0 { + return f.write_str(""); + } + write!(f, "<{:?}>", self.size) + } +} + +impl Unpin for CryptoVec {} +unsafe impl Send for CryptoVec {} +unsafe impl Sync for CryptoVec {} + +// Common traits implementations +impl AsRef<[u8]> for CryptoVec { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for CryptoVec { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +impl Deref for CryptoVec { + type Target = [u8]; + fn deref(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.p, self.size) } + } +} + +impl DerefMut for CryptoVec { + fn deref_mut(&mut self) -> &mut [u8] { + unsafe { std::slice::from_raw_parts_mut(self.p, self.size) } + } +} + +impl From for CryptoVec { + fn from(e: String) -> Self { + CryptoVec::from(e.into_bytes()) + } +} + +impl From<&str> for CryptoVec { + fn from(e: &str) -> Self { + CryptoVec::from(e.as_bytes()) + } +} + +impl From<&[u8]> for CryptoVec { + fn from(e: &[u8]) -> Self { + CryptoVec::from_slice(e) + } +} + +impl From> for CryptoVec { + fn from(e: Vec) -> Self { + let mut c = CryptoVec::new_zeroed(e.len()); + c.clone_from_slice(&e[..]); + c + } +} + +// Indexing implementations +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: RangeFrom) -> &[u8] { + self.deref().index(index) + } +} +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: RangeTo) -> &[u8] { + self.deref().index(index) + } +} +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: Range) -> &[u8] { + self.deref().index(index) + } +} +impl Index for CryptoVec { + type Output = [u8]; + fn index(&self, _: RangeFull) -> &[u8] { + self.deref() + } +} + +impl IndexMut for CryptoVec { + fn index_mut(&mut self, _: RangeFull) -> &mut [u8] { + self.deref_mut() + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: RangeFrom) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: RangeTo) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: Range) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} + +impl Index for CryptoVec { + type Output = u8; + fn index(&self, index: usize) -> &u8 { + self.deref().index(index) + } +} + +// IO-related implementation +impl std::io::Write for CryptoVec { + fn write(&mut self, buf: &[u8]) -> Result { + self.extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> Result<(), std::io::Error> { + Ok(()) + } +} + +// Default implementation +impl Default for CryptoVec { + fn default() -> Self { + CryptoVec { + p: std::ptr::NonNull::dangling().as_ptr(), + size: 0, + capacity: 0, + } + } +} + +impl CryptoVec { + /// Creates a new `CryptoVec`. + pub fn new() -> CryptoVec { + CryptoVec::default() + } + + /// Creates a new `CryptoVec` with `n` zeros. + pub fn new_zeroed(size: usize) -> CryptoVec { + unsafe { + let capacity = size.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); + let p = std::alloc::alloc_zeroed(layout); + let _ = mlock(p, capacity); + CryptoVec { p, capacity, size } + } + } + + /// Creates a new `CryptoVec` with capacity `capacity`. + pub fn with_capacity(capacity: usize) -> CryptoVec { + unsafe { + let capacity = capacity.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); + let p = std::alloc::alloc_zeroed(layout); + let _ = mlock(p, capacity); + CryptoVec { + p, + capacity, + size: 0, + } + } + } + + /// Length of this `CryptoVec`. + /// + /// ``` + /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0) + /// ``` + pub fn len(&self) -> usize { + self.size + } + + /// Returns `true` if and only if this CryptoVec is empty. + /// + /// ``` + /// assert!(russh_cryptovec::CryptoVec::new().is_empty()) + /// ``` + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Resize this CryptoVec, appending zeros at the end. This may + /// perform at most one reallocation, overwriting the previous + /// version with zeros. + pub fn resize(&mut self, size: usize) { + if size <= self.capacity && size > self.size { + // If this is an expansion, just resize. + self.size = size + } else if size <= self.size { + // If this is a truncation, resize and erase the extra memory. + unsafe { + memset(self.p.add(size), 0, self.size - size); + } + self.size = size; + } else { + // realloc ! and erase the previous memory. + unsafe { + let next_capacity = size.next_power_of_two(); + let old_ptr = self.p; + let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); + self.p = std::alloc::alloc_zeroed(next_layout); + let _ = mlock(self.p, next_capacity); + + if self.capacity > 0 { + std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); + for i in 0..self.size { + std::ptr::write_volatile(old_ptr.add(i), 0) + } + let _ = munlock(old_ptr, self.capacity); + let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); + std::alloc::dealloc(old_ptr, layout); + } + + if self.p.is_null() { + #[allow(clippy::panic)] + { + panic!("Realloc failed, pointer = {:?} {:?}", self, size) + } + } else { + self.capacity = next_capacity; + self.size = size; + } + } + } + } + + /// Clear this CryptoVec (retaining the memory). + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"blabla"); + /// v.clear(); + /// assert!(v.is_empty()) + /// ``` + pub fn clear(&mut self) { + self.resize(0); + } + + /// Append a new byte at the end of this CryptoVec. + pub fn push(&mut self, s: u8) { + let size = self.size; + self.resize(size + 1); + unsafe { *self.p.add(size) = s } + } + + /// Read `n_bytes` from `r`, and append them at the end of this + /// `CryptoVec`. Returns the number of bytes read (and appended). + pub fn read( + &mut self, + n_bytes: usize, + mut r: R, + ) -> Result { + let cur_size = self.size; + self.resize(cur_size + n_bytes); + let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) }; + // Resize the buffer to its appropriate size. + match r.read(s) { + Ok(n) => { + self.resize(cur_size + n); + Ok(n) + } + Err(e) => { + self.resize(cur_size); + Err(e) + } + } + } + + /// Write all this CryptoVec to the provided `Write`. Returns the + /// number of bytes actually written. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"blabla"); + /// let mut s = std::io::stdout(); + /// v.write_all_from(0, &mut s).unwrap(); + /// ``` + pub fn write_all_from( + &self, + offset: usize, + mut w: W, + ) -> Result { + assert!(offset < self.size); + // if we're past this point, self.p cannot be null. + unsafe { + let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset); + w.write(s) + } + } + + /// Resize this CryptoVec, returning a mutable borrow to the extra bytes. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.resize_mut(4).clone_from_slice(b"test"); + /// ``` + pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { + let size = self.size; + self.resize(size + n); + unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } + } + + /// Append a slice at the end of this CryptoVec. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"test"); + /// ``` + pub fn extend(&mut self, s: &[u8]) { + let size = self.size; + self.resize(size + s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); + } + } + + /// Create a `CryptoVec` from a slice + /// + /// ``` + /// russh_cryptovec::CryptoVec::from_slice(b"test"); + /// ``` + pub fn from_slice(s: &[u8]) -> CryptoVec { + let mut v = CryptoVec::new(); + v.resize(s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len()); + } + v + } +} + +impl Clone for CryptoVec { + fn clone(&self) -> Self { + let mut v = Self::new(); + v.extend(self); + v + } +} + +// Drop implementation +impl Drop for CryptoVec { + fn drop(&mut self) { + if self.capacity > 0 { + unsafe { + for i in 0..self.size { + std::ptr::write_volatile(self.p.add(i), 0); + } + let _ = platform::munlock(self.p, self.capacity); + let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); + std::alloc::dealloc(self.p, layout); + } + } + } +} + +#[cfg(test)] +mod test { + use super::CryptoVec; + + #[test] + fn test_new() { + let crypto_vec = CryptoVec::new(); + assert_eq!(crypto_vec.size, 0); + assert_eq!(crypto_vec.capacity, 0); + } + + #[test] + fn test_resize_expand() { + let mut crypto_vec = CryptoVec::new_zeroed(5); + crypto_vec.resize(10); + assert_eq!(crypto_vec.size, 10); + assert!(crypto_vec.capacity >= 10); + assert!(crypto_vec.iter().skip(5).all(|&x| x == 0)); // Ensure newly added elements are zeroed + } + + #[test] + fn test_resize_shrink() { + let mut crypto_vec = CryptoVec::new_zeroed(10); + crypto_vec.resize(5); + assert_eq!(crypto_vec.size, 5); + // Ensure shrinking keeps the previous elements intact + assert_eq!(crypto_vec.len(), 5); + } + + #[test] + fn test_push() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.push(1); + crypto_vec.push(2); + assert_eq!(crypto_vec.size, 2); + assert_eq!(crypto_vec[0], 1); + assert_eq!(crypto_vec[1], 2); + } + + #[test] + fn test_write_trait() { + use std::io::Write; + + let mut crypto_vec = CryptoVec::new(); + let bytes_written = crypto_vec.write(&[1, 2, 3]).unwrap(); + assert_eq!(bytes_written, 3); + assert_eq!(crypto_vec.size, 3); + assert_eq!(crypto_vec.as_ref(), &[1, 2, 3]); + } + + #[test] + fn test_as_ref_as_mut() { + let mut crypto_vec = CryptoVec::new_zeroed(5); + let slice_ref: &[u8] = crypto_vec.as_ref(); + assert_eq!(slice_ref.len(), 5); + let slice_mut: &mut [u8] = crypto_vec.as_mut(); + slice_mut[0] = 1; + assert_eq!(crypto_vec[0], 1); + } + + #[test] + fn test_from_string() { + let input = String::from("hello"); + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), b"hello"); + } + + #[test] + fn test_from_str() { + let input = "hello"; + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), b"hello"); + } + + #[test] + fn test_from_byte_slice() { + let input = b"hello".as_slice(); + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), b"hello"); + } + + #[test] + fn test_from_vec() { + let input = vec![1, 2, 3, 4]; + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), &[1, 2, 3, 4]); + } + + #[test] + fn test_index() { + let crypto_vec = CryptoVec::from(vec![1, 2, 3, 4, 5]); + assert_eq!(crypto_vec[0], 1); + assert_eq!(crypto_vec[4], 5); + assert_eq!(&crypto_vec[1..3], &[2, 3]); + } + + #[test] + fn test_drop() { + let mut crypto_vec = CryptoVec::new_zeroed(10); + // Ensure vector is filled with non-zero data + crypto_vec.extend(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + drop(crypto_vec); + + // Check that memory zeroing was done during the drop + // This part is more difficult to test directly since it involves + // private memory management. However, with Rust's unsafe features, + // it may be checked using tools like Valgrind or manual inspection. + } + + #[test] + fn test_new_zeroed() { + let crypto_vec = CryptoVec::new_zeroed(10); + assert_eq!(crypto_vec.size, 10); + assert!(crypto_vec.capacity >= 10); + assert!(crypto_vec.iter().all(|&x| x == 0)); // Ensure all bytes are zeroed + } + + #[test] + fn test_clear() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"blabla"); + crypto_vec.clear(); + assert!(crypto_vec.is_empty()); + } + + #[test] + fn test_extend() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"test"); + assert_eq!(crypto_vec.as_ref(), b"test"); + } + + #[test] + fn test_write_all_from() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"blabla"); + + let mut output: Vec = Vec::new(); + let written_size = crypto_vec.write_all_from(0, &mut output).unwrap(); + assert_eq!(written_size, 6); // "blabla" has 6 bytes + assert_eq!(output, b"blabla"); + } + + #[test] + fn test_resize_mut() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.resize_mut(4).clone_from_slice(b"test"); + assert_eq!(crypto_vec.as_ref(), b"test"); + } + + // DocTests cannot be run on with wasm_bindgen_test + #[cfg(target_arch = "wasm32")] + mod wasm32 { + use wasm_bindgen_test::wasm_bindgen_test; + + use super::*; + + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + #[wasm_bindgen_test] + fn test_push_u32_be() { + let mut crypto_vec = CryptoVec::new(); + let value = 43554u32; + crypto_vec.push_u32_be(value); + assert_eq!(crypto_vec.len(), 4); // u32 is 4 bytes long + assert_eq!(crypto_vec.read_u32_be(0), value); + } + + #[wasm_bindgen_test] + fn test_read_u32_be() { + let mut crypto_vec = CryptoVec::new(); + let value = 99485710u32; + crypto_vec.push_u32_be(value); + assert_eq!(crypto_vec.read_u32_be(0), value); + } + } +} diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 8ecd1f0d..c1f4f778 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -4,6 +4,7 @@ clippy::indexing_slicing, clippy::panic )] + // Copyright 2016 Pierre-Étienne Meunier // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,414 +19,13 @@ // See the License for the specific language governing permissions and // limitations under the License. // -use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; - -use libc::c_void; -#[cfg(not(windows))] -use libc::size_t; - -/// A buffer which zeroes its memory on `.clear()`, `.resize()` and -/// reallocations, to avoid copying secrets around. -#[derive(Debug)] -pub struct CryptoVec { - p: *mut u8, - size: usize, - capacity: usize, -} - -impl Unpin for CryptoVec {} - -unsafe impl Send for CryptoVec {} -unsafe impl Sync for CryptoVec {} - -impl AsRef<[u8]> for CryptoVec { - fn as_ref(&self) -> &[u8] { - self.deref() - } -} -impl AsMut<[u8]> for CryptoVec { - fn as_mut(&mut self) -> &mut [u8] { - self.deref_mut() - } -} -impl Deref for CryptoVec { - type Target = [u8]; - fn deref(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.p, self.size) } - } -} -impl DerefMut for CryptoVec { - fn deref_mut(&mut self) -> &mut [u8] { - unsafe { std::slice::from_raw_parts_mut(self.p, self.size) } - } -} - -impl From for CryptoVec { - fn from(e: String) -> Self { - CryptoVec::from(e.into_bytes()) - } -} - -impl From> for CryptoVec { - fn from(e: Vec) -> Self { - let mut c = CryptoVec::new_zeroed(e.len()); - c.clone_from_slice(&e[..]); - c - } -} - -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: RangeFrom) -> &[u8] { - self.deref().index(index) - } -} -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: RangeTo) -> &[u8] { - self.deref().index(index) - } -} -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: Range) -> &[u8] { - self.deref().index(index) - } -} -impl Index for CryptoVec { - type Output = [u8]; - fn index(&self, _: RangeFull) -> &[u8] { - self.deref() - } -} -impl IndexMut for CryptoVec { - fn index_mut(&mut self, _: RangeFull) -> &mut [u8] { - self.deref_mut() - } -} - -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: RangeFrom) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: RangeTo) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: Range) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} - -impl Index for CryptoVec { - type Output = u8; - fn index(&self, index: usize) -> &u8 { - self.deref().index(index) - } -} - -impl std::io::Write for CryptoVec { - fn write(&mut self, buf: &[u8]) -> Result { - self.extend(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> Result<(), std::io::Error> { - Ok(()) - } -} - -impl Default for CryptoVec { - fn default() -> Self { - CryptoVec { - p: std::ptr::NonNull::dangling().as_ptr(), - size: 0, - capacity: 0, - } - } -} - -#[cfg(not(windows))] -unsafe fn mlock(ptr: *const u8, len: usize) { - libc::mlock(ptr as *const c_void, len as size_t); -} -#[cfg(not(windows))] -unsafe fn munlock(ptr: *const u8, len: usize) { - libc::munlock(ptr as *const c_void, len as size_t); -} - -#[cfg(windows)] -use winapi::shared::basetsd::SIZE_T; -#[cfg(windows)] -use winapi::shared::minwindef::LPVOID; -#[cfg(windows)] -use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; -#[cfg(windows)] -unsafe fn mlock(ptr: *const u8, len: usize) { - VirtualLock(ptr as LPVOID, len as SIZE_T); -} -#[cfg(windows)] -unsafe fn munlock(ptr: *const u8, len: usize) { - VirtualUnlock(ptr as LPVOID, len as SIZE_T); -} - -impl Clone for CryptoVec { - fn clone(&self) -> Self { - let mut v = Self::new(); - v.extend(self); - v - } -} - -impl CryptoVec { - /// Creates a new `CryptoVec`. - pub fn new() -> CryptoVec { - CryptoVec::default() - } - - /// Creates a new `CryptoVec` with `n` zeros. - pub fn new_zeroed(size: usize) -> CryptoVec { - unsafe { - let capacity = size.next_power_of_two(); - let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); - let p = std::alloc::alloc_zeroed(layout); - mlock(p, capacity); - CryptoVec { p, capacity, size } - } - } - - /// Creates a new `CryptoVec` with capacity `capacity`. - pub fn with_capacity(capacity: usize) -> CryptoVec { - unsafe { - let capacity = capacity.next_power_of_two(); - let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); - let p = std::alloc::alloc_zeroed(layout); - mlock(p, capacity); - CryptoVec { - p, - capacity, - size: 0, - } - } - } - - /// Length of this `CryptoVec`. - /// - /// ``` - /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0) - /// ``` - pub fn len(&self) -> usize { - self.size - } - - /// Returns `true` if and only if this CryptoVec is empty. - /// - /// ``` - /// assert!(russh_cryptovec::CryptoVec::new().is_empty()) - /// ``` - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Resize this CryptoVec, appending zeros at the end. This may - /// perform at most one reallocation, overwriting the previous - /// version with zeros. - pub fn resize(&mut self, size: usize) { - if size <= self.capacity && size > self.size { - // If this is an expansion, just resize. - self.size = size - } else if size <= self.size { - // If this is a truncation, resize and erase the extra memory. - unsafe { - libc::memset(self.p.add(size) as *mut c_void, 0, self.size - size); - } - self.size = size; - } else { - // realloc ! and erase the previous memory. - unsafe { - let next_capacity = size.next_power_of_two(); - let old_ptr = self.p; - let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); - self.p = std::alloc::alloc_zeroed(next_layout); - mlock(self.p, next_capacity); - - if self.capacity > 0 { - std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); - for i in 0..self.size { - std::ptr::write_volatile(old_ptr.add(i), 0) - } - munlock(old_ptr, self.capacity); - let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); - std::alloc::dealloc(old_ptr, layout); - } - - if self.p.is_null() { - #[allow(clippy::panic)] - { - panic!("Realloc failed, pointer = {:?} {:?}", self, size) - } - } else { - self.capacity = next_capacity; - self.size = size; - } - } - } - } - - /// Clear this CryptoVec (retaining the memory). - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"blabla"); - /// v.clear(); - /// assert!(v.is_empty()) - /// ``` - pub fn clear(&mut self) { - self.resize(0); - } - - /// Append a new byte at the end of this CryptoVec. - pub fn push(&mut self, s: u8) { - let size = self.size; - self.resize(size + 1); - unsafe { *self.p.add(size) = s } - } - - /// Append a new u32, big endian-encoded, at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 43554; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn push_u32_be(&mut self, s: u32) { - let s = s.to_be(); - let x: [u8; 4] = s.to_ne_bytes(); - self.extend(&x) - } - - /// Read a big endian-encoded u32 from this CryptoVec, with the - /// first byte at position `i`. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 99485710; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn read_u32_be(&self, i: usize) -> u32 { - assert!(i + 4 <= self.size); - let mut x: u32 = 0; - unsafe { - libc::memcpy( - (&mut x) as *mut u32 as *mut c_void, - self.p.add(i) as *const c_void, - 4, - ); - } - u32::from_be(x) - } - - /// Read `n_bytes` from `r`, and append them at the end of this - /// `CryptoVec`. Returns the number of bytes read (and appended). - pub fn read( - &mut self, - n_bytes: usize, - mut r: R, - ) -> Result { - let cur_size = self.size; - self.resize(cur_size + n_bytes); - let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) }; - // Resize the buffer to its appropriate size. - match r.read(s) { - Ok(n) => { - self.resize(cur_size + n); - Ok(n) - } - Err(e) => { - self.resize(cur_size); - Err(e) - } - } - } - - /// Write all this CryptoVec to the provided `Write`. Returns the - /// number of bytes actually written. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"blabla"); - /// let mut s = std::io::stdout(); - /// v.write_all_from(0, &mut s).unwrap(); - /// ``` - pub fn write_all_from( - &self, - offset: usize, - mut w: W, - ) -> Result { - assert!(offset < self.size); - // if we're past this point, self.p cannot be null. - unsafe { - let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset); - w.write(s) - } - } - - /// Resize this CryptoVec, returning a mutable borrow to the extra bytes. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.resize_mut(4).clone_from_slice(b"test"); - /// ``` - pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { - let size = self.size; - self.resize(size + n); - unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } - } - /// Append a slice at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"test"); - /// ``` - pub fn extend(&mut self, s: &[u8]) { - let size = self.size; - self.resize(size + s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); - } - } +// Re-export CryptoVec from the cryptovec module +mod cryptovec; +pub use cryptovec::CryptoVec; - /// Create a `CryptoVec` from a slice - /// - /// ``` - /// russh_cryptovec::CryptoVec::from_slice(b"test"); - /// ``` - pub fn from_slice(s: &[u8]) -> CryptoVec { - let mut v = CryptoVec::new(); - v.resize(s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len()); - } - v - } -} +// Platform-specific modules +mod platform; -impl Drop for CryptoVec { - fn drop(&mut self) { - if self.capacity > 0 { - unsafe { - for i in 0..self.size { - std::ptr::write_volatile(self.p.add(i), 0) - } - munlock(self.p, self.capacity); - let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); - std::alloc::dealloc(self.p, layout); - } - } - } -} +#[cfg(feature = "ssh-encoding")] +mod ssh; diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs new file mode 100644 index 00000000..f6b0c87c --- /dev/null +++ b/cryptovec/src/platform/mod.rs @@ -0,0 +1,77 @@ +#[cfg(windows)] +mod windows; + +#[cfg(not(windows))] +#[cfg(not(target_arch = "wasm32"))] +mod unix; + +#[cfg(target_arch = "wasm32")] +mod wasm; + +// Re-export functions based on the platform +#[cfg(not(windows))] +#[cfg(not(target_arch = "wasm32"))] +pub use unix::{memset, mlock, munlock}; +#[cfg(target_arch = "wasm32")] +pub use wasm::{memset, mlock, munlock}; +#[cfg(windows)] +pub use windows::{memset, mlock, munlock}; + +#[cfg(not(target_arch = "wasm32"))] +mod error { + use std::error::Error; + use std::fmt::Display; + use std::sync::atomic::{AtomicBool, Ordering}; + + use log::warn; + + #[derive(Debug)] + pub struct MemoryLockError { + message: String, + } + + impl MemoryLockError { + pub fn new(message: String) -> Self { + let warning_previously_shown = MLOCK_WARNING_SHOWN.swap(true, Ordering::Relaxed); + if !warning_previously_shown { + warn!("Security warning: OS has failed to lock/unlock memory for a cryptographic buffer: {}", message); + #[cfg(unix)] + warn!("You might need to increase the RLIMIT_MEMLOCK limit."); + warn!("This warning will only be shown once."); + } + Self { message } + } + } + + static MLOCK_WARNING_SHOWN: AtomicBool = AtomicBool::new(false); + + impl Display for MemoryLockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "failed to lock/unlock memory: {}", self.message) + } + } + + impl Error for MemoryLockError {} +} + +#[cfg(not(target_arch = "wasm32"))] +pub use error::MemoryLockError; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memset() { + let mut buf = vec![0u8; 10]; + memset(buf.as_mut_ptr(), 0xff, buf.len()); + assert_eq!(buf, vec![0xff; 10]); + } + + #[test] + fn test_memset_partial() { + let mut buf = vec![0u8; 10]; + memset(buf.as_mut_ptr(), 0xff, 5); + assert_eq!(buf, [0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0]); + } +} diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs new file mode 100644 index 00000000..c7596368 --- /dev/null +++ b/cryptovec/src/platform/unix.rs @@ -0,0 +1,34 @@ +use std::ffi::c_void; +use std::ptr::NonNull; + +use nix::errno::Errno; + +use super::MemoryLockError; + +/// Unlock memory on drop for Unix-based systems. +pub fn munlock(ptr: *const u8, len: usize) -> Result<(), MemoryLockError> { + unsafe { + Errno::clear(); + let ptr = NonNull::new_unchecked(ptr as *mut c_void); + nix::sys::mman::munlock(ptr, len).map_err(|e| { + MemoryLockError::new(format!("munlock: {} (0x{:x})", e.desc(), e as i32)) + })?; + } + Ok(()) +} + +pub fn mlock(ptr: *const u8, len: usize) -> Result<(), MemoryLockError> { + unsafe { + Errno::clear(); + let ptr = NonNull::new_unchecked(ptr as *mut c_void); + nix::sys::mman::mlock(ptr, len) + .map_err(|e| MemoryLockError::new(format!("mlock: {} (0x{:x})", e.desc(), e as i32)))?; + } + Ok(()) +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + unsafe { + nix::libc::memset(ptr as *mut c_void, value, size); + } +} diff --git a/cryptovec/src/platform/wasm.rs b/cryptovec/src/platform/wasm.rs new file mode 100644 index 00000000..55402df5 --- /dev/null +++ b/cryptovec/src/platform/wasm.rs @@ -0,0 +1,18 @@ +use std::convert::Infallible; + +// WASM does not support synchronization primitives +pub fn munlock(_ptr: *const u8, _len: usize) -> Result<(), Infallible> { + // No-op + Ok(()) +} + +pub fn mlock(_ptr: *const u8, _len: usize) -> Result<(), Infallible> { + Ok(()) +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + let byte_value = value as u8; // Extract the least significant byte directly + unsafe { + std::ptr::write_bytes(ptr, byte_value, size); + } +} diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs new file mode 100644 index 00000000..3f0f162d --- /dev/null +++ b/cryptovec/src/platform/windows.rs @@ -0,0 +1,111 @@ +use std::collections::btree_map::Entry; +use std::collections::BTreeMap; +use std::ffi::c_void; +use std::sync::{Mutex, OnceLock}; + +use winapi::shared::basetsd::SIZE_T; +use winapi::shared::minwindef::LPVOID; +use winapi::um::errhandlingapi::GetLastError; +use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; +use winapi::um::sysinfoapi::{GetNativeSystemInfo, SYSTEM_INFO}; + +use super::MemoryLockError; + +// To correctly lock/unlock memory, we need to know the pagesize: +static PAGE_SIZE: OnceLock = OnceLock::new(); +// Store refcounters for all locked pages, since Windows doesn't handle that for us: +static LOCKED_PAGES: Mutex> = Mutex::new(BTreeMap::new()); + +/// Unlock memory on drop for Windows. +pub fn munlock(ptr: *const u8, len: usize) -> Result<(), MemoryLockError> { + let page_indices = get_page_indices(ptr, len); + let mut locked_pages = LOCKED_PAGES + .lock() + .map_err(|e| MemoryLockError::new(format!("Accessing PageLocks failed: {e}")))?; + for page_idx in page_indices { + match locked_pages.entry(page_idx) { + Entry::Occupied(mut lock_counter) => { + let lock_counter_val = lock_counter.get_mut(); + *lock_counter_val -= 1; + if *lock_counter_val == 0 { + lock_counter.remove(); + unlock_page(page_idx)?; + } + } + Entry::Vacant(_) => { + return Err(MemoryLockError::new( + "Tried to unlock pointer from non-locked page!".into(), + )); + } + } + } + Ok(()) +} + +fn unlock_page(page_idx: usize) -> Result<(), MemoryLockError> { + unsafe { + if VirtualUnlock((page_idx * get_page_size()) as LPVOID, 1 as SIZE_T) == 0 { + // codes can be looked up at https://learn.microsoft.com/en-us/windows/win32/debug/system-error-codes + let errorcode = GetLastError(); + return Err(MemoryLockError::new(format!( + "VirtualUnlock: 0x{errorcode:x}" + ))); + } + } + Ok(()) +} + +pub fn mlock(ptr: *const u8, len: usize) -> Result<(), MemoryLockError> { + let page_indices = get_page_indices(ptr, len); + let mut locked_pages = LOCKED_PAGES + .lock() + .map_err(|e| MemoryLockError::new(format!("Accessing PageLocks failed: {e}")))?; + for page_idx in page_indices { + match locked_pages.entry(page_idx) { + Entry::Occupied(mut lock_counter) => { + let lock_counter_val = lock_counter.get_mut(); + *lock_counter_val += 1; + } + Entry::Vacant(lock_counter) => { + lock_page(page_idx)?; + lock_counter.insert(1); + } + } + } + Ok(()) +} + +fn lock_page(page_idx: usize) -> Result<(), MemoryLockError> { + unsafe { + if VirtualLock((page_idx * get_page_size()) as LPVOID, 1 as SIZE_T) == 0 { + let errorcode = GetLastError(); + return Err(MemoryLockError::new(format!( + "VirtualLock: 0x{errorcode:x}" + ))); + } + } + Ok(()) +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + unsafe { + libc::memset(ptr as *mut c_void, value, size); + } +} + +fn get_page_size() -> usize { + *PAGE_SIZE.get_or_init(|| { + let mut sys_info = SYSTEM_INFO::default(); + unsafe { + GetNativeSystemInfo(&mut sys_info); + } + sys_info.dwPageSize as usize + }) +} + +fn get_page_indices(ptr: *const u8, len: usize) -> std::ops::Range { + let page_size = get_page_size(); + let first_page = ptr as usize / page_size; + let page_count = (len + page_size - 1) / page_size; + first_page..(first_page + page_count) +} diff --git a/cryptovec/src/ssh.rs b/cryptovec/src/ssh.rs new file mode 100644 index 00000000..846dd793 --- /dev/null +++ b/cryptovec/src/ssh.rs @@ -0,0 +1,20 @@ +use ssh_encoding::{Reader, Result, Writer}; + +use crate::CryptoVec; + +impl Reader for CryptoVec { + fn read<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8]> { + (&self[..]).read(out) + } + + fn remaining_len(&self) -> usize { + self.len() + } +} + +impl Writer for CryptoVec { + fn write(&mut self, bytes: &[u8]) -> Result<()> { + self.extend(bytes); + Ok(()) + } +} diff --git a/pageant/Cargo.toml b/pageant/Cargo.toml new file mode 100644 index 00000000..d1abd864 --- /dev/null +++ b/pageant/Cargo.toml @@ -0,0 +1,26 @@ +[package] +authors = ["Eugene "] +description = "Pageant SSH agent transport client." +documentation = "https://docs.rs/pageant" +edition = "2021" +license = "Apache-2.0" +name = "pageant" +repository = "https://github.com/warp-tech/russh" +version = "0.0.3" +rust-version = "1.75" + +[target.'cfg(windows)'.dependencies] +futures.workspace = true +thiserror.workspace = true +rand.workspace = true +log.workspace = true +tokio = { workspace = true, features = ["io-util", "rt"] } +bytes.workspace = true +delegate.workspace = true +windows = { version = "0.58", features = [ + "Win32_UI_WindowsAndMessaging", + "Win32_System_Memory", + "Win32_Security", + "Win32_System_Threading", + "Win32_System_DataExchange", +] } diff --git a/pageant/src/lib.rs b/pageant/src/lib.rs new file mode 100644 index 00000000..5648077c --- /dev/null +++ b/pageant/src/lib.rs @@ -0,0 +1,18 @@ +//! # Pageant SSH agent transport protocol implementation +//! +//! This crate provides a [PageantStream] type that implements [AsyncRead] and [AsyncWrite] traits and can be used to talk to a running Pageant instance. +//! +//! This crate only implements the transport, not the actual SSH agent protocol. + +#![deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic +)] + +#[cfg(windows)] +mod pageant_impl; + +#[cfg(windows)] +pub use pageant_impl::*; diff --git a/pageant/src/pageant_impl.rs b/pageant/src/pageant_impl.rs new file mode 100644 index 00000000..5651a8f0 --- /dev/null +++ b/pageant/src/pageant_impl.rs @@ -0,0 +1,300 @@ +use std::io::IoSlice; +use std::mem::size_of; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::BytesMut; +use delegate::delegate; +use log::debug; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf}; +use windows::core::HSTRING; +use windows::Win32::Foundation::{CloseHandle, HANDLE, HWND, INVALID_HANDLE_VALUE, LPARAM, WPARAM}; +use windows::Win32::Security::{ + GetTokenInformation, InitializeSecurityDescriptor, SetSecurityDescriptorOwner, TokenUser, + PSECURITY_DESCRIPTOR, SECURITY_ATTRIBUTES, SECURITY_DESCRIPTOR, TOKEN_QUERY, TOKEN_USER, +}; +use windows::Win32::System::DataExchange::COPYDATASTRUCT; +use windows::Win32::System::Memory::{ + CreateFileMappingW, MapViewOfFile, UnmapViewOfFile, FILE_MAP_WRITE, MEMORY_MAPPED_VIEW_ADDRESS, + PAGE_READWRITE, +}; +use windows::Win32::System::Threading::{GetCurrentProcess, OpenProcessToken}; +use windows::Win32::UI::WindowsAndMessaging::{FindWindowW, SendMessageA, WM_COPYDATA}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Pageant not found")] + NotFound, + + #[error("Buffer overflow")] + Overflow, + + #[error("No response from Pageant")] + NoResponse, + + #[error(transparent)] + WindowsError(#[from] windows::core::Error), +} + +impl Error { + fn from_win32() -> Self { + Self::WindowsError(windows::core::Error::from_win32()) + } +} + +/// Pageant transport stream. Implements [AsyncRead] and [AsyncWrite]. +/// +/// The stream has a unique cookie and requests made in the same stream are considered the same "session". +pub struct PageantStream { + stream: DuplexStream, +} + +impl PageantStream { + pub fn new() -> Self { + let (one, mut two) = tokio::io::duplex(_AGENT_MAX_MSGLEN * 100); + + let cookie = rand::random::().to_string(); + tokio::spawn(async move { + let mut buf = BytesMut::new(); + while let Ok(n) = two.read_buf(&mut buf).await { + if n == 0 { + break; + } + let msg = buf.split().freeze(); + let Ok(response) = query_pageant_direct(cookie.clone(), &msg).map_err(|e| { + debug!("Pageant query failed: {:?}", e); + e + }) else { + break; + }; + two.write_all(&response).await? + } + std::io::Result::Ok(()) + }); + + Self { stream: one } + } +} + +impl Default for PageantStream { + fn default() -> Self { + Self::new() + } +} + +impl AsyncRead for PageantStream { + delegate! { + to Pin::new(&mut self.stream) { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll>; + + } + } +} + +impl AsyncWrite for PageantStream { + delegate! { + to Pin::new(&mut self.stream) { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll>; + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + } + + to Pin::new(&self.stream) { + fn is_write_vectored(&self) -> bool; + } + } +} + +struct MemoryMap { + filemap: HANDLE, + view: MEMORY_MAPPED_VIEW_ADDRESS, + length: usize, + pos: usize, +} + +impl MemoryMap { + fn new( + name: String, + length: usize, + security_attributes: Option, + ) -> Result { + let filemap = unsafe { + CreateFileMappingW( + INVALID_HANDLE_VALUE, + security_attributes.map(|sa| &sa as *const _), + PAGE_READWRITE, + 0, + length as u32, + &HSTRING::from(name.clone()), + ) + }?; + if filemap.is_invalid() { + return Err(Error::from_win32()); + } + let view = unsafe { MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0) }; + Ok(Self { + filemap, + view, + length, + pos: 0, + }) + } + + fn seek(&mut self, pos: usize) { + self.pos = pos; + } + + fn write(&mut self, data: &[u8]) -> Result<(), Error> { + if self.pos + data.len() > self.length { + return Err(Error::Overflow); + } + + if data.is_empty() { + return Ok(()); + } + + unsafe { + #[allow(clippy::indexing_slicing)] // length checked + std::ptr::copy_nonoverlapping( + &data[0] as *const u8, + self.view.Value.add(self.pos) as *mut u8, + data.len(), + ); + } + self.pos += data.len(); + Ok(()) + } + + fn read(&mut self, n: usize) -> Vec { + let out = vec![0; n]; + unsafe { + std::ptr::copy_nonoverlapping( + self.view.Value.add(self.pos) as *const u8, + out.as_ptr() as *mut u8, + n, + ); + } + self.pos += n; + out + } +} + +impl Drop for MemoryMap { + fn drop(&mut self) { + unsafe { + let _ = UnmapViewOfFile(self.view); + let _ = CloseHandle(self.filemap); + } + } +} + +fn find_pageant_window() -> Result { + let w = unsafe { FindWindowW(&HSTRING::from("Pageant"), &HSTRING::from("Pageant")) }?; + if w.is_invalid() { + return Err(Error::NotFound); + } + Ok(w) +} + +const _AGENT_COPYDATA_ID: u64 = 0x804E50BA; +const _AGENT_MAX_MSGLEN: usize = 8192; + +pub fn is_pageant_running() -> bool { + find_pageant_window().is_ok() +} + +/// Send a one-off query to Pageant and return a response. +pub fn query_pageant_direct(cookie: String, msg: &[u8]) -> Result, Error> { + let hwnd = find_pageant_window()?; + let map_name = format!("PageantRequest{cookie}"); + + let user = unsafe { + let mut process_token = HANDLE::default(); + OpenProcessToken( + GetCurrentProcess(), + TOKEN_QUERY, + &mut process_token as *mut _, + )?; + + let mut info_size = 0; + let _ = GetTokenInformation(process_token, TokenUser, None, 0, &mut info_size); + + let mut buffer = vec![0; info_size as usize]; + GetTokenInformation( + process_token, + TokenUser, + Some(buffer.as_mut_ptr() as *mut _), + buffer.len() as u32, + &mut info_size, + )?; + let user: TOKEN_USER = *(buffer.as_ptr() as *const _); + let _ = CloseHandle(process_token); + user + }; + + let mut sd = SECURITY_DESCRIPTOR::default(); + let sa = SECURITY_ATTRIBUTES { + lpSecurityDescriptor: &mut sd as *mut _ as *mut _, + bInheritHandle: true.into(), + ..Default::default() + }; + + let psd = PSECURITY_DESCRIPTOR(&mut sd as *mut _ as *mut _); + + unsafe { + InitializeSecurityDescriptor(psd, 1)?; + SetSecurityDescriptorOwner(psd, user.User.Sid, false)?; + } + + let mut map: MemoryMap = MemoryMap::new(map_name.clone(), _AGENT_MAX_MSGLEN, Some(sa))?; + map.write(msg)?; + + let mut char_buffer = map_name.as_bytes().to_vec(); + char_buffer.push(0); + let cds = COPYDATASTRUCT { + dwData: _AGENT_COPYDATA_ID as usize, + cbData: char_buffer.len() as u32, + lpData: char_buffer.as_ptr() as *mut _, + }; + + let response = unsafe { + SendMessageA( + hwnd, + WM_COPYDATA, + WPARAM(size_of::()), + LPARAM(&cds as *const _ as isize), + ) + }; + + if response.0 == 0 { + return Err(Error::NoResponse); + } + + map.seek(0); + let mut buf = map.read(4); + if buf.len() < 4 { + return Err(Error::NoResponse); + } + #[allow(clippy::indexing_slicing)] // length checked + let size = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + buf.extend(map.read(size)); + + Ok(buf) +} diff --git a/russh-config/Cargo.toml b/russh-config/Cargo.toml index 1ad2846b..6fa5c81e 100644 --- a/russh-config/Cargo.toml +++ b/russh-config/Cargo.toml @@ -2,17 +2,19 @@ authors = ["Pierre-Étienne Meunier "] description = "Utilities to parse .ssh/config files, including helpers to implement ProxyCommand in Russh." documentation = "https://docs.rs/russh-config" -edition = "2018" +edition = "2021" include = ["Cargo.toml", "src/lib.rs", "src/proxy.rs"] license = "Apache-2.0" name = "russh-config" repository = "https://github.com/warp-tech/russh" -version = "0.7.0" +version = "0.50.0" +rust-version = "1.75" [dependencies] -dirs-next = "2.0" -futures = "0.3" -log = "0.4" -thiserror = "1.0" -tokio = {version = "1.0", features = ["io-util", "net", "macros", "process"]} +home.workspace = true +futures.workspace = true +globset = "0.3" +log.workspace = true +thiserror.workspace = true +tokio = { workspace = true, features = ["io-util", "net", "macros", "process"] } whoami = "1.2" diff --git a/russh-config/src/lib.rs b/russh-config/src/lib.rs index cdf4d95a..a269d364 100644 --- a/russh-config/src/lib.rs +++ b/russh-config/src/lib.rs @@ -8,6 +8,7 @@ use std::io::Read; use std::net::ToSocketAddrs; use std::path::Path; +use globset::Glob; use log::debug; use thiserror::*; @@ -34,7 +35,10 @@ pub struct Config { pub port: u16, pub identity_file: Option, pub proxy_command: Option, + pub proxy_jump: Option, pub add_keys_to_agent: AddKeysToAgent, + pub user_known_hosts_file: Option, + pub strict_host_key_checking: bool, } impl Config { @@ -45,22 +49,32 @@ impl Config { port: 22, identity_file: None, proxy_command: None, + proxy_jump: None, add_keys_to_agent: AddKeysToAgent::default(), + user_known_hosts_file: None, + strict_host_key_checking: true, } } } impl Config { - fn update_proxy_command(&mut self) { - if let Some(ref mut prox) = self.proxy_command { - *prox = prox.replace("%h", &self.host_name); - *prox = prox.replace("%p", &format!("{}", self.port)); - } + // Look for any of the ssh_config(5) percent-style tokens and expand them + // based on current data in the struct, returning a new String. This function + // can be employed late/lazy eg just before establishing a stream using ProxyCommand + // but also can be used to modify Hostname as config parse time + fn expand_tokens(&self, original: &str) -> String { + let mut string = original.to_string(); + string = string.replace("%u", &self.user); + string = string.replace("%h", &self.host_name); // remote hostname (from context "host") + string = string.replace("%H", &self.host_name); // remote hostname (from context "host") + string = string.replace("%p", &format!("{}", self.port)); // original typed hostname (from context "host") + string = string.replace("%%", "%"); + string } - pub async fn stream(&mut self) -> Result { - self.update_proxy_command(); + pub async fn stream(&self) -> Result { if let Some(ref proxy_command) = self.proxy_command { + let proxy_command = self.expand_tokens(proxy_command); let cmd: Vec<&str> = proxy_command.split(' ').collect(); Stream::proxy_command(cmd.first().unwrap_or(&""), cmd.get(1..).unwrap_or(&[])) .await @@ -76,7 +90,7 @@ impl Config { } pub fn parse_home(host: &str) -> Result { - let mut home = if let Some(home) = dirs_next::home_dir() { + let mut home = if let Some(home) = home::home_dir() { home } else { return Err(Error::NoHome); @@ -93,86 +107,199 @@ pub fn parse_path>(path: P, host: &str) -> Result parse(&s, host) } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] pub enum AddKeysToAgent { Yes, Confirm, Ask, + #[default] No, } -impl Default for AddKeysToAgent { - fn default() -> Self { - AddKeysToAgent::No - } -} - pub fn parse(file: &str, host: &str) -> Result { - let mut config: Option = None; + let mut config = Config::default(host); + let mut matches_current = false; for line in file.lines() { - let line = line.trim(); - if let Some(n) = line.find(' ') { - let (key, value) = line.split_at(n); + let tokens = line.trim().splitn(2, ' ').collect::>(); + if tokens.len() == 2 { + let (key, value) = (tokens.first().unwrap_or(&""), tokens.get(1).unwrap_or(&"")); let lower = key.to_lowercase(); - if let Some(ref mut config) = config { + if lower.as_str() == "host" { + matches_current = value + .split_whitespace() + .any(|x| check_host_against_glob_pattern(host, x)); + } + if matches_current { match lower.as_str() { - "host" => break, "user" => { config.user.clear(); config.user.push_str(value.trim_start()); } - "hostname" => { - config.host_name.clear(); - config.host_name.push_str(value.trim_start()) - } + "hostname" => config.host_name = config.expand_tokens(value.trim_start()), "port" => { if let Ok(port) = value.trim_start().parse() { config.port = port } } "identityfile" => { - let id = value.trim_start(); - if id.starts_with("~/") { - if let Some(mut home) = dirs_next::home_dir() { - home.push(id.split_at(2).1); - config.identity_file = Some( - home.to_str() - .ok_or_else(|| { - std::io::Error::new( - std::io::ErrorKind::Other, - "Failed to convert home directory to string", - ) - })? - .to_string(), - ); - } else { - return Err(Error::NoHome); - } - } else { - config.identity_file = Some(id.to_string()) - } + config.identity_file = + Some(value.trim_start().strip_quotes().expand_home()?); } "proxycommand" => config.proxy_command = Some(value.trim_start().to_string()), + "proxyjump" => config.proxy_jump = Some(value.trim_start().to_string()), "addkeystoagent" => match value.to_lowercase().as_str() { "yes" => config.add_keys_to_agent = AddKeysToAgent::Yes, "confirm" => config.add_keys_to_agent = AddKeysToAgent::Confirm, "ask" => config.add_keys_to_agent = AddKeysToAgent::Ask, _ => config.add_keys_to_agent = AddKeysToAgent::No, }, + "userknownhostsfile" => { + config.user_known_hosts_file = + Some(value.trim_start().strip_quotes().expand_home()?); + } + "stricthostkeychecking" => match value.to_lowercase().as_str() { + "no" => config.strict_host_key_checking = false, + _ => config.strict_host_key_checking = true, + }, key => { debug!("{:?}", key); } } - } else if lower.as_str() == "host" && value.trim_start() == host { - let mut c = Config::default(host); - c.port = 22; - config = Some(c) } } } - if let Some(config) = config { - Ok(config) - } else { - Err(Error::HostNotFound) + Ok(config) +} + +fn check_host_against_glob_pattern(candidate: &str, glob_pattern: &str) -> bool { + match Glob::new(glob_pattern) { + Ok(glob) => glob.compile_matcher().is_match(candidate), + _ => false, + } +} + +trait SshConfigStrExt { + fn strip_quotes(&self) -> Self; + fn expand_home(&self) -> Result; +} + +impl SshConfigStrExt for &str { + fn strip_quotes(&self) -> Self { + if self.len() > 1 + && ((self.starts_with('\'') && self.ends_with('\'')) + || (self.starts_with('\"') && self.ends_with('\"'))) + { + #[allow(clippy::indexing_slicing)] // length checked + &self[1..self.len() - 1] + } else { + self + } + } + + fn expand_home(&self) -> Result { + if self.starts_with("~/") { + if let Some(mut home) = home::home_dir() { + home.push(self.split_at(2).1); + Ok(home + .to_str() + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to convert home directory to string", + ) + })? + .to_string()) + } else { + Err(Error::NoHome) + } + } else { + Ok(self.to_string()) + } + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::expect_used)] + use crate::{parse, AddKeysToAgent, Config, SshConfigStrExt}; + + #[test] + fn strip_quotes() { + let value = "'this is a test'"; + assert_eq!("this is a test", value.strip_quotes()); + let value = "\"this is a test\""; + assert_eq!("this is a test", value.strip_quotes()); + let value = "'this is a test\""; + assert_eq!("'this is a test\"", value.strip_quotes()); + let value = "'this is a test"; + assert_eq!("'this is a test", value.strip_quotes()); + let value = "this is a test'"; + assert_eq!("this is a test'", value.strip_quotes()); + let value = "this is a test"; + assert_eq!("this is a test", value.strip_quotes()); + let value = ""; + assert_eq!("", value.strip_quotes()); + let value = "'"; + assert_eq!("'", value.strip_quotes()); + let value = "''"; + assert_eq!("", value.strip_quotes()); + } + + #[test] + fn expand_home() { + let value = "~/some/folder".expand_home().expect("expand_home"); + assert_eq!( + format!( + "{}{}", + home::home_dir().expect("homedir").to_str().expect("to_str"), + "/some/folder" + ), + value + ); + } + + #[test] + fn default_config() { + let config: Config = Config::default("some_host"); + assert_eq!(whoami::username(), config.user); + assert_eq!("some_host", config.host_name); + assert_eq!(22, config.port); + assert_eq!(None, config.identity_file); + assert_eq!(None, config.proxy_command); + assert_eq!(None, config.proxy_jump); + assert_eq!(AddKeysToAgent::No, config.add_keys_to_agent); + assert_eq!(None, config.user_known_hosts_file); + assert!(config.strict_host_key_checking); + } + + #[test] + fn basic_config() { + let value = r"# +Host test_host + IdentityFile '~/.ssh/id_ed25519' + User trinity + Hostname foo.com + Port 23 + UserKnownHostsFile /some/special/host_file + StrictHostKeyChecking no +#"; + let identity_file = format!( + "{}{}", + home::home_dir().expect("homedir").to_str().expect("to_str"), + "/.ssh/id_ed25519" + ); + let config = parse(value, "test_host").expect("parse"); + assert_eq!("trinity", config.user); + assert_eq!("foo.com", config.host_name); + assert_eq!(23, config.port); + assert_eq!(Some(identity_file), config.identity_file); + assert_eq!(None, config.proxy_command); + assert_eq!(None, config.proxy_jump); + assert_eq!(AddKeysToAgent::No, config.add_keys_to_agent); + assert_eq!( + Some("/some/special/host_file"), + config.user_known_hosts_file.as_deref() + ); + assert!(!config.strict_host_key_checking); } } diff --git a/russh-keys/Cargo.toml b/russh-keys/Cargo.toml deleted file mode 100644 index fd5bcc80..00000000 --- a/russh-keys/Cargo.toml +++ /dev/null @@ -1,74 +0,0 @@ -[package] -authors = ["Pierre-Étienne Meunier "] -description = "Deal with SSH keys: load them, decrypt them, call an SSH agent." -documentation = "https://docs.rs/russh-keys" -edition = "2018" -homepage = "https://github.com/warp-tech/russh" -include = [ - "Cargo.toml", - "src/lib.rs", - "src/agent/mod.rs", - "src/agent/msg.rs", - "src/agent/server.rs", - "src/agent/client.rs", - "src/bcrypt_pbkdf.rs", - "src/blowfish.rs", - "src/encoding.rs", - "src/format/mod.rs", - "src/format/openssh.rs", - "src/format/pkcs5.rs", - "src/format/pkcs8.rs", - "src/key.rs", - "src/signature.rs", -] -keywords = ["ssh"] -license = "Apache-2.0" -name = "russh-keys" -repository = "https://github.com/warp-tech/russh" -version = "0.37.1" - -[dependencies] -aes = "0.8" -async-trait = "0.1.72" -bcrypt-pbkdf = "0.10" -bit-vec = "0.6" -cbc = "0.1" -ctr = "0.9" -block-padding = { version = "0.3", features = ["std"] } -byteorder = "1.4" -data-encoding = "2.3" -dirs = "5.0" -ed25519-dalek = { version= "2.0", features = ["rand_core"] } -futures = "0.3" -hmac = "0.12" -inout = { version = "0.1", features = ["std"] } -log = "0.4" -md5 = "0.7" -num-bigint = "0.4" -num-integer = "0.1" -openssl = { version = "0.10", optional = true } -pbkdf2 = "0.11" -rand = "0.7" -rand_core = { version = "0.6.4", features = ["std"] } -russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } -serde = { version = "1.0", features = ["derive"] } -sha2 = "0.10" -thiserror = "1.0" -tokio = { version = "1.17.0", features = [ - "io-util", - "rt-multi-thread", - "time", - "net", -] } -tokio-stream = { version = "0.1", features = ["net"] } -yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"] } - -[features] -vendored-openssl = ["openssl", "openssl/vendored"] - -[dev-dependencies] -env_logger = "0.10" -tempdir = "0.3" - -[package.metadata.docs.rs] -features = ["openssl"] diff --git a/russh-keys/src/agent/client.rs b/russh-keys/src/agent/client.rs deleted file mode 100644 index 1e730030..00000000 --- a/russh-keys/src/agent/client.rs +++ /dev/null @@ -1,531 +0,0 @@ -use std::convert::TryFrom; - -use byteorder::{BigEndian, ByteOrder}; -use log::{debug, info}; -use russh_cryptovec::CryptoVec; -use tokio; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -use super::{msg, Constraint}; -use crate::encoding::{Encoding, Reader}; -use crate::key::{PublicKey, SignatureHash}; -use crate::{key, Error}; - -/// SSH agent client. -pub struct AgentClient { - stream: S, - buf: CryptoVec, -} - -// https://tools.ietf.org/html/draft-miller-ssh-agent-00#section-4.1 -impl AgentClient { - /// Build a future that connects to an SSH agent via the provided - /// stream (on Unix, usually a Unix-domain socket). - pub fn connect(stream: S) -> Self { - AgentClient { - stream, - buf: CryptoVec::new(), - } - } -} - -#[cfg(unix)] -impl AgentClient { - /// Build a future that connects to an SSH agent via the provided - /// stream (on Unix, usually a Unix-domain socket). - pub async fn connect_uds>(path: P) -> Result { - let stream = tokio::net::UnixStream::connect(path).await?; - Ok(AgentClient { - stream, - buf: CryptoVec::new(), - }) - } - - /// Build a future that connects to an SSH agent via the provided - /// stream (on Unix, usually a Unix-domain socket). - pub async fn connect_env() -> Result { - let var = if let Ok(var) = std::env::var("SSH_AUTH_SOCK") { - var - } else { - return Err(Error::EnvVar("SSH_AUTH_SOCK")); - }; - match Self::connect_uds(var).await { - Err(Error::IO(io_err)) if io_err.kind() == std::io::ErrorKind::NotFound => { - Err(Error::BadAuthSock) - } - owise => owise, - } - } -} - -#[cfg(not(unix))] -impl AgentClient { - /// Build a future that connects to an SSH agent via the provided - /// stream (on Unix, usually a Unix-domain socket). - pub async fn connect_env() -> Result { - Err(Error::AgentFailure) - } -} - -impl AgentClient { - async fn read_response(&mut self) -> Result<(), Error> { - // Writing the message - self.stream.write_all(&self.buf).await?; - self.stream.flush().await?; - - // Reading the length - self.buf.clear(); - self.buf.resize(4); - self.stream.read_exact(&mut self.buf).await?; - - // Reading the rest of the buffer - let len = BigEndian::read_u32(&self.buf) as usize; - self.buf.clear(); - self.buf.resize(len); - self.stream.read_exact(&mut self.buf).await?; - - Ok(()) - } - - /// Send a key to the agent, with a (possibly empty) slice of - /// constraints to apply when using the key to sign. - pub async fn add_identity( - &mut self, - key: &key::KeyPair, - constraints: &[Constraint], - ) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - if constraints.is_empty() { - self.buf.push(msg::ADD_IDENTITY) - } else { - self.buf.push(msg::ADD_ID_CONSTRAINED) - } - match *key { - key::KeyPair::Ed25519(ref pair) => { - self.buf.extend_ssh_string(b"ssh-ed25519"); - self.buf.extend_ssh_string(pair.verifying_key().as_bytes()); - self.buf.push_u32_be(64); - self.buf.extend(pair.to_bytes().as_slice()); - self.buf.extend(pair.verifying_key().as_bytes()); - self.buf.extend_ssh_string(b""); - } - #[cfg(feature = "openssl")] - #[allow(clippy::unwrap_used)] // key is known to be private - key::KeyPair::RSA { ref key, .. } => { - self.buf.extend_ssh_string(b"ssh-rsa"); - self.buf.extend_ssh_mpint(&key.n().to_vec()); - self.buf.extend_ssh_mpint(&key.e().to_vec()); - self.buf.extend_ssh_mpint(&key.d().to_vec()); - if let Some(iqmp) = key.iqmp() { - self.buf.extend_ssh_mpint(&iqmp.to_vec()); - } else { - let mut ctx = openssl::bn::BigNumContext::new()?; - let mut iqmp = openssl::bn::BigNum::new()?; - iqmp.mod_inverse(key.p().unwrap(), key.q().unwrap(), &mut ctx)?; - self.buf.extend_ssh_mpint(&iqmp.to_vec()); - } - self.buf.extend_ssh_mpint(&key.p().unwrap().to_vec()); - self.buf.extend_ssh_mpint(&key.q().unwrap().to_vec()); - self.buf.extend_ssh_string(b""); - } - } - if !constraints.is_empty() { - self.buf.push_u32_be(constraints.len() as u32); - for cons in constraints { - match *cons { - Constraint::KeyLifetime { seconds } => { - self.buf.push(msg::CONSTRAIN_LIFETIME); - self.buf.push_u32_be(seconds) - } - Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), - Constraint::Extensions { - ref name, - ref details, - } => { - self.buf.push(msg::CONSTRAIN_EXTENSION); - self.buf.extend_ssh_string(name); - self.buf.extend_ssh_string(details); - } - } - } - } - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - - self.read_response().await?; - Ok(()) - } - - /// Add a smart card to the agent, with a (possibly empty) set of - /// constraints to apply when signing. - pub async fn add_smartcard_key( - &mut self, - id: &str, - pin: &[u8], - constraints: &[Constraint], - ) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - if constraints.is_empty() { - self.buf.push(msg::ADD_SMARTCARD_KEY) - } else { - self.buf.push(msg::ADD_SMARTCARD_KEY_CONSTRAINED) - } - self.buf.extend_ssh_string(id.as_bytes()); - self.buf.extend_ssh_string(pin); - if !constraints.is_empty() { - self.buf.push_u32_be(constraints.len() as u32); - for cons in constraints { - match *cons { - Constraint::KeyLifetime { seconds } => { - self.buf.push(msg::CONSTRAIN_LIFETIME); - self.buf.push_u32_be(seconds) - } - Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), - Constraint::Extensions { - ref name, - ref details, - } => { - self.buf.push(msg::CONSTRAIN_EXTENSION); - self.buf.extend_ssh_string(name); - self.buf.extend_ssh_string(details); - } - } - } - } - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - Ok(()) - } - - /// Lock the agent, making it refuse to sign until unlocked. - pub async fn lock(&mut self, passphrase: &[u8]) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::LOCK); - self.buf.extend_ssh_string(passphrase); - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - Ok(()) - } - - /// Unlock the agent, allowing it to sign again. - pub async fn unlock(&mut self, passphrase: &[u8]) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::UNLOCK); - self.buf.extend_ssh_string(passphrase); - let len = self.buf.len() - 4; - #[allow(clippy::indexing_slicing)] // static length - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - Ok(()) - } - - /// Ask the agent for a list of the currently registered secret - /// keys. - pub async fn request_identities(&mut self) -> Result, Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::REQUEST_IDENTITIES); - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - - self.read_response().await?; - debug!("identities: {:?}", &self.buf[..]); - let mut keys = Vec::new(); - - #[allow(clippy::indexing_slicing)] // static length - if self.buf[0] == msg::IDENTITIES_ANSWER { - let mut r = self.buf.reader(1); - let n = r.read_u32()?; - for _ in 0..n { - let key = r.read_string()?; - let _ = r.read_string()?; - let mut r = key.reader(0); - let t = r.read_string()?; - debug!("t = {:?}", std::str::from_utf8(t)); - match t { - #[cfg(feature = "openssl")] - b"ssh-rsa" => { - let e = r.read_mpint()?; - let n = r.read_mpint()?; - use openssl::bn::BigNum; - use openssl::pkey::PKey; - use openssl::rsa::Rsa; - keys.push(PublicKey::RSA { - key: key::OpenSSLPKey(PKey::from_rsa(Rsa::from_public_components( - BigNum::from_slice(n)?, - BigNum::from_slice(e)?, - )?)?), - hash: SignatureHash::SHA2_512, - }) - } - b"ssh-ed25519" => keys.push(PublicKey::Ed25519( - ed25519_dalek::VerifyingKey::try_from(r.read_string()?)?, - )), - t => { - info!("Unsupported key type: {:?}", std::str::from_utf8(t)) - } - } - } - } - - Ok(keys) - } - - /// Ask the agent to sign the supplied piece of data. - pub fn sign_request( - mut self, - public: &key::PublicKey, - mut data: CryptoVec, - ) -> impl futures::Future)> { - debug!("sign_request: {:?}", data); - let hash = self.prepare_sign_request(public, &data); - - async move { - if let Err(e) = hash { - return (self, Err(e)); - } - - let resp = self.read_response().await; - debug!("resp = {:?}", &self.buf[..]); - if let Err(e) = resp { - return (self, Err(e)); - } - - #[allow(clippy::indexing_slicing, clippy::unwrap_used)] - // length is checked, hash already checked - if !self.buf.is_empty() && self.buf[0] == msg::SIGN_RESPONSE { - let resp = self.write_signature(hash.unwrap(), &mut data); - if let Err(e) = resp { - return (self, Err(e)); - } - (self, Ok(data)) - } else if self.buf.first() == Some(&msg::FAILURE) { - (self, Err(Error::AgentFailure)) - } else { - debug!("self.buf = {:?}", &self.buf[..]); - (self, Ok(data)) - } - } - } - - fn prepare_sign_request(&mut self, public: &key::PublicKey, data: &[u8]) -> Result { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::SIGN_REQUEST); - key_blob(public, &mut self.buf)?; - self.buf.extend_ssh_string(data); - debug!("public = {:?}", public); - let hash = match public { - #[cfg(feature = "openssl")] - PublicKey::RSA { hash, .. } => match hash { - SignatureHash::SHA2_256 => 2, - SignatureHash::SHA2_512 => 4, - SignatureHash::SHA1 => 0, - }, - _ => 0, - }; - self.buf.push_u32_be(hash); - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - Ok(hash) - } - - fn write_signature(&self, hash: u32, data: &mut CryptoVec) -> Result<(), Error> { - let mut r = self.buf.reader(1); - let mut resp = r.read_string()?.reader(0); - let t = resp.read_string()?; - if (hash == 2 && t == b"rsa-sha2-256") || (hash == 4 && t == b"rsa-sha2-512") || hash == 0 { - let sig = resp.read_string()?; - data.push_u32_be((t.len() + sig.len() + 8) as u32); - data.extend_ssh_string(t); - data.extend_ssh_string(sig); - } - Ok(()) - } - - /// Ask the agent to sign the supplied piece of data. - pub fn sign_request_base64( - mut self, - public: &key::PublicKey, - data: &[u8], - ) -> impl futures::Future)> { - debug!("sign_request: {:?}", data); - let r = self.prepare_sign_request(public, data); - async move { - if let Err(e) = r { - return (self, Err(e)); - } - - let resp = self.read_response().await; - if let Err(e) = resp { - return (self, Err(e)); - } - - #[allow(clippy::indexing_slicing)] // length is checked - if !self.buf.is_empty() && self.buf[0] == msg::SIGN_RESPONSE { - let base64 = data_encoding::BASE64_NOPAD.encode(&self.buf[1..]); - (self, Ok(base64)) - } else { - (self, Ok(String::new())) - } - } - } - - /// Ask the agent to sign the supplied piece of data, and return a `Signature`. - pub fn sign_request_signature( - mut self, - public: &key::PublicKey, - data: &[u8], - ) -> impl futures::Future)> { - debug!("sign_request: {:?}", data); - - let r = self.prepare_sign_request(public, data); - - async move { - if let Err(e) = r { - return (self, Err(e)); - } - - if let Err(e) = self.read_response().await { - return (self, Err(e)); - } - - #[allow(clippy::indexing_slicing)] // length is checked - if !self.buf.is_empty() && self.buf[0] == msg::SIGN_RESPONSE { - let as_sig = |buf: &CryptoVec| -> Result { - let mut r = buf.reader(1); - let mut resp = r.read_string()?.reader(0); - let typ = resp.read_string()?; - let sig = resp.read_string()?; - use crate::signature::Signature; - match typ { - b"ssh-rsa" => Ok(Signature::RSA { - bytes: sig.to_vec(), - hash: SignatureHash::SHA1, - }), - b"rsa-sha2-256" => Ok(Signature::RSA { - bytes: sig.to_vec(), - hash: SignatureHash::SHA2_256, - }), - b"rsa-sha2-512" => Ok(Signature::RSA { - bytes: sig.to_vec(), - hash: SignatureHash::SHA2_512, - }), - b"ssh-ed25519" => { - let mut sig_bytes = [0; 64]; - sig_bytes.clone_from_slice(sig); - Ok(Signature::Ed25519(crate::signature::SignatureBytes( - sig_bytes, - ))) - } - _ => Err(Error::UnknownSignatureType { - sig_type: std::str::from_utf8(typ).unwrap_or("").to_string(), - }), - } - }; - let sig = as_sig(&self.buf); - (self, sig) - } else { - (self, Err(Error::AgentProtocolError)) - } - } - } - - /// Ask the agent to remove a key from its memory. - pub async fn remove_identity(&mut self, public: &key::PublicKey) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::REMOVE_IDENTITY); - key_blob(public, &mut self.buf)?; - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - Ok(()) - } - - /// Ask the agent to remove a smartcard from its memory. - pub async fn remove_smartcard_key(&mut self, id: &str, pin: &[u8]) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::REMOVE_SMARTCARD_KEY); - self.buf.extend_ssh_string(id.as_bytes()); - self.buf.extend_ssh_string(pin); - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - Ok(()) - } - - /// Ask the agent to forget all known keys. - pub async fn remove_all_identities(&mut self) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::REMOVE_ALL_IDENTITIES); - BigEndian::write_u32(&mut self.buf[..], 5); - self.read_response().await?; - Ok(()) - } - - /// Send a custom message to the agent. - pub async fn extension(&mut self, typ: &[u8], ext: &[u8]) -> Result<(), Error> { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::EXTENSION); - self.buf.extend_ssh_string(typ); - self.buf.extend_ssh_string(ext); - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - Ok(()) - } - - /// Ask the agent what extensions about supported extensions. - pub async fn query_extension(&mut self, typ: &[u8], mut ext: CryptoVec) -> Result { - self.buf.clear(); - self.buf.resize(4); - self.buf.push(msg::EXTENSION); - self.buf.extend_ssh_string(typ); - let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); - self.read_response().await?; - - let mut r = self.buf.reader(1); - ext.extend(r.read_string()?); - - #[allow(clippy::indexing_slicing)] // length is checked - Ok(!self.buf.is_empty() && self.buf[0] == msg::SUCCESS) - } -} - -fn key_blob(public: &key::PublicKey, buf: &mut CryptoVec) -> Result<(), Error> { - match *public { - #[cfg(feature = "openssl")] - PublicKey::RSA { ref key, .. } => { - buf.extend(&[0, 0, 0, 0]); - let len0 = buf.len(); - buf.extend_ssh_string(b"ssh-rsa"); - let rsa = key.0.rsa()?; - buf.extend_ssh_mpint(&rsa.e().to_vec()); - buf.extend_ssh_mpint(&rsa.n().to_vec()); - let len1 = buf.len(); - #[allow(clippy::indexing_slicing)] // length is known - BigEndian::write_u32(&mut buf[5..], (len1 - len0) as u32); - } - PublicKey::Ed25519(ref p) => { - buf.extend(&[0, 0, 0, 0]); - let len0 = buf.len(); - buf.extend_ssh_string(b"ssh-ed25519"); - buf.extend_ssh_string(p.as_bytes()); - let len1 = buf.len(); - #[allow(clippy::indexing_slicing)] // length is known - BigEndian::write_u32(&mut buf[5..], (len1 - len0) as u32); - } - } - Ok(()) -} diff --git a/russh-keys/src/encoding.rs b/russh-keys/src/encoding.rs deleted file mode 100644 index 0f64f724..00000000 --- a/russh-keys/src/encoding.rs +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2016 Pierre-Étienne Meunier -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use russh_cryptovec::CryptoVec; - -use crate::Error; - -#[doc(hidden)] -pub trait Bytes { - fn bytes(&self) -> &[u8]; -} - -impl> Bytes for A { - fn bytes(&self) -> &[u8] { - self.as_ref().as_bytes() - } -} - -/// Encode in the SSH format. -pub trait Encoding { - /// Push an SSH-encoded string to `self`. - fn extend_ssh_string(&mut self, s: &[u8]); - /// Push an SSH-encoded blank string of length `s` to `self`. - fn extend_ssh_string_blank(&mut self, s: usize) -> &mut [u8]; - /// Push an SSH-encoded multiple-precision integer. - fn extend_ssh_mpint(&mut self, s: &[u8]); - /// Push an SSH-encoded list. - fn extend_list>(&mut self, list: I); - /// Push an SSH-encoded empty list. - fn write_empty_list(&mut self); -} - -/// Encoding length of the given mpint. -#[allow(clippy::indexing_slicing)] -pub fn mpint_len(s: &[u8]) -> usize { - let mut i = 0; - while i < s.len() && s[i] == 0 { - i += 1 - } - (if s[i] & 0x80 != 0 { 5 } else { 4 }) + s.len() - i -} - -impl Encoding for Vec { - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - fn extend_ssh_string(&mut self, s: &[u8]) { - self.write_u32::(s.len() as u32).unwrap(); - self.extend(s); - } - - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] { - self.write_u32::(len as u32).unwrap(); - let current = self.len(); - self.resize(current + len, 0u8); - #[allow(clippy::indexing_slicing)] // length is known - &mut self[current..] - } - - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - #[allow(clippy::indexing_slicing)] // length is known - fn extend_ssh_mpint(&mut self, s: &[u8]) { - // Skip initial 0s. - let mut i = 0; - while i < s.len() && s[i] == 0 { - i += 1 - } - // If the first non-zero is >= 128, write its length (u32, BE), followed by 0. - if s[i] & 0x80 != 0 { - self.write_u32::((s.len() - i + 1) as u32) - .unwrap(); - self.push(0) - } else { - self.write_u32::((s.len() - i) as u32).unwrap(); - } - self.extend(&s[i..]); - } - - #[allow(clippy::indexing_slicing)] // length is known - fn extend_list>(&mut self, list: I) { - let len0 = self.len(); - self.extend([0, 0, 0, 0]); - let mut first = true; - for i in list { - if !first { - self.push(b',') - } else { - first = false; - } - self.extend(i.bytes()) - } - let len = (self.len() - len0 - 4) as u32; - - BigEndian::write_u32(&mut self[len0..], len); - } - - fn write_empty_list(&mut self) { - self.extend([0, 0, 0, 0]); - } -} - -impl Encoding for CryptoVec { - fn extend_ssh_string(&mut self, s: &[u8]) { - self.push_u32_be(s.len() as u32); - self.extend(s); - } - - #[allow(clippy::indexing_slicing)] // length is known - fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] { - self.push_u32_be(len as u32); - let current = self.len(); - self.resize(current + len); - &mut self[current..] - } - - #[allow(clippy::indexing_slicing)] // length is known - fn extend_ssh_mpint(&mut self, s: &[u8]) { - // Skip initial 0s. - let mut i = 0; - while i < s.len() && s[i] == 0 { - i += 1 - } - // If the first non-zero is >= 128, write its length (u32, BE), followed by 0. - if s[i] & 0x80 != 0 { - self.push_u32_be((s.len() - i + 1) as u32); - self.push(0) - } else { - self.push_u32_be((s.len() - i) as u32); - } - self.extend(&s[i..]); - } - - fn extend_list>(&mut self, list: I) { - let len0 = self.len(); - self.extend(&[0, 0, 0, 0]); - let mut first = true; - for i in list { - if !first { - self.push(b',') - } else { - first = false; - } - self.extend(i.bytes()) - } - let len = (self.len() - len0 - 4) as u32; - - #[allow(clippy::indexing_slicing)] // length is known - BigEndian::write_u32(&mut self[len0..], len); - } - - fn write_empty_list(&mut self) { - self.extend(&[0, 0, 0, 0]); - } -} - -/// A cursor-like trait to read SSH-encoded things. -pub trait Reader { - /// Create an SSH reader for `self`. - fn reader(&self, starting_at: usize) -> Position; -} - -impl Reader for CryptoVec { - fn reader(&self, starting_at: usize) -> Position { - Position { - s: self, - position: starting_at, - } - } -} - -impl Reader for [u8] { - fn reader(&self, starting_at: usize) -> Position { - Position { - s: self, - position: starting_at, - } - } -} - -/// A cursor-like type to read SSH-encoded values. -#[derive(Debug)] -pub struct Position<'a> { - s: &'a [u8], - #[doc(hidden)] - pub position: usize, -} -impl<'a> Position<'a> { - /// Read one string from this reader. - pub fn read_string(&mut self) -> Result<&'a [u8], Error> { - let len = self.read_u32()? as usize; - if self.position + len <= self.s.len() { - #[allow(clippy::indexing_slicing)] // length is known - let result = &self.s[self.position..(self.position + len)]; - self.position += len; - Ok(result) - } else { - Err(Error::IndexOutOfBounds) - } - } - /// Read a `u32` from this reader. - pub fn read_u32(&mut self) -> Result { - if self.position + 4 <= self.s.len() { - #[allow(clippy::indexing_slicing)] // length is known - let u = BigEndian::read_u32(&self.s[self.position..]); - self.position += 4; - Ok(u) - } else { - Err(Error::IndexOutOfBounds) - } - } - /// Read one byte from this reader. - pub fn read_byte(&mut self) -> Result { - if self.position < self.s.len() { - #[allow(clippy::indexing_slicing)] // length is known - let u = self.s[self.position]; - self.position += 1; - Ok(u) - } else { - Err(Error::IndexOutOfBounds) - } - } - - /// Read one byte from this reader. - pub fn read_mpint(&mut self) -> Result<&'a [u8], Error> { - let len = self.read_u32()? as usize; - if self.position + len <= self.s.len() { - #[allow(clippy::indexing_slicing)] // length was checked - let result = &self.s[self.position..(self.position + len)]; - self.position += len; - Ok(result) - } else { - Err(Error::IndexOutOfBounds) - } - } -} diff --git a/russh-keys/src/format/openssh.rs b/russh-keys/src/format/openssh.rs deleted file mode 100644 index 44821fb8..00000000 --- a/russh-keys/src/format/openssh.rs +++ /dev/null @@ -1,169 +0,0 @@ -use std::convert::TryFrom; - -use aes::cipher::block_padding::NoPadding; -use aes::cipher::{BlockDecryptMut, KeyIvInit, StreamCipher}; -use bcrypt_pbkdf; -use ctr::Ctr64BE; -#[cfg(feature = "openssl")] -use openssl::bn::BigNum; - -use crate::encoding::Reader; -use crate::{key, Error, KEYTYPE_ED25519, KEYTYPE_RSA}; - -/// Decode a secret key given in the OpenSSH format, deciphering it if -/// needed using the supplied password. -pub fn decode_openssh(secret: &[u8], password: Option<&str>) -> Result { - if matches!(secret.get(0..15), Some(b"openssh-key-v1\0")) { - let mut position = secret.reader(15); - - let ciphername = position.read_string()?; - let kdfname = position.read_string()?; - let kdfoptions = position.read_string()?; - - let nkeys = position.read_u32()?; - - // Read all public keys - for _ in 0..nkeys { - position.read_string()?; - } - - // Read all secret keys - let secret_ = position.read_string()?; - let secret = decrypt_secret_key(ciphername, kdfname, kdfoptions, password, secret_)?; - let mut position = secret.reader(0); - let _check0 = position.read_u32()?; - let _check1 = position.read_u32()?; - #[allow(clippy::never_loop)] - for _ in 0..nkeys { - // TODO check: never really loops beyond the first key - let key_type = position.read_string()?; - if key_type == KEYTYPE_ED25519 { - let pubkey = position.read_string()?; - let seckey = position.read_string()?; - let _comment = position.read_string()?; - if Some(pubkey) != seckey.get(32..) { - return Err(Error::KeyIsCorrupt); - } - let secret = ed25519_dalek::SigningKey::try_from( - seckey.get(..32).ok_or(Error::KeyIsCorrupt)?, - )?; - return Ok(key::KeyPair::Ed25519(secret)); - } else if key_type == KEYTYPE_RSA && cfg!(feature = "openssl") { - #[cfg(feature = "openssl")] - { - let n = BigNum::from_slice(position.read_string()?)?; - let e = BigNum::from_slice(position.read_string()?)?; - let d = BigNum::from_slice(position.read_string()?)?; - let iqmp = BigNum::from_slice(position.read_string()?)?; - let p = BigNum::from_slice(position.read_string()?)?; - let q = BigNum::from_slice(position.read_string()?)?; - - let mut ctx = openssl::bn::BigNumContext::new()?; - let un = openssl::bn::BigNum::from_u32(1)?; - let mut p1 = openssl::bn::BigNum::new()?; - let mut q1 = openssl::bn::BigNum::new()?; - p1.checked_sub(&p, &un)?; - q1.checked_sub(&q, &un)?; - let mut dmp1 = openssl::bn::BigNum::new()?; // d mod p-1 - dmp1.checked_rem(&d, &p1, &mut ctx)?; - let mut dmq1 = openssl::bn::BigNum::new()?; // d mod q-1 - dmq1.checked_rem(&d, &q1, &mut ctx)?; - - let key = openssl::rsa::RsaPrivateKeyBuilder::new(n, e, d)? - .set_factors(p, q)? - .set_crt_params(dmp1, dmq1, iqmp)? - .build(); - key.check_key()?; - return Ok(key::KeyPair::RSA { - key, - hash: key::SignatureHash::SHA2_512, - }); - } - } else { - return Err(Error::UnsupportedKeyType { - key_type_string: String::from_utf8(key_type.to_vec()) - .unwrap_or_else(|_| format!("{key_type:?}")), - key_type_raw: key_type.to_vec(), - }); - } - } - Err(Error::CouldNotReadKey) - } else { - Err(Error::CouldNotReadKey) - } -} - -use aes::*; - -fn decrypt_secret_key( - ciphername: &[u8], - kdfname: &[u8], - kdfoptions: &[u8], - password: Option<&str>, - secret_key: &[u8], -) -> Result, Error> { - if kdfname == b"none" { - if password.is_none() { - Ok(secret_key.to_vec()) - } else { - Err(Error::CouldNotReadKey) - } - } else if let Some(password) = password { - let mut key = [0; 48]; - let n = match ciphername { - b"aes128-cbc" | b"aes128-ctr" => 32, - b"aes256-cbc" | b"aes256-ctr" => 48, - _ => return Err(Error::CouldNotReadKey), - }; - match kdfname { - b"bcrypt" => { - let mut kdfopts = kdfoptions.reader(0); - let salt = kdfopts.read_string()?; - let rounds = kdfopts.read_u32()?; - #[allow(clippy::unwrap_used)] // parameters are static - #[allow(clippy::indexing_slicing)] // output length is static - match bcrypt_pbkdf::bcrypt_pbkdf(password, salt, rounds, &mut key[..n]) { - Err(bcrypt_pbkdf::Error::InvalidParamLen) => return Err(Error::KeyIsEncrypted), - e => e.unwrap(), - } - } - _kdfname => { - return Err(Error::CouldNotReadKey); - } - }; - let (key, iv) = key.split_at(n - 16); - - let mut dec = secret_key.to_vec(); - dec.resize(dec.len() + 32, 0u8); - match ciphername { - b"aes128-cbc" => { - #[allow(clippy::unwrap_used)] // parameters are static - let cipher = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let n = cipher.decrypt_padded_mut::(&mut dec)?.len(); - dec.truncate(n) - } - b"aes256-cbc" => { - #[allow(clippy::unwrap_used)] // parameters are static - let cipher = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let n = cipher.decrypt_padded_mut::(&mut dec)?.len(); - dec.truncate(n) - } - b"aes128-ctr" => { - #[allow(clippy::unwrap_used)] // parameters are static - let mut cipher = Ctr64BE::::new_from_slices(key, iv).unwrap(); - cipher.apply_keystream(&mut dec); - dec.truncate(secret_key.len()) - } - b"aes256-ctr" => { - #[allow(clippy::unwrap_used)] // parameters are static - let mut cipher = Ctr64BE::::new_from_slices(key, iv).unwrap(); - cipher.apply_keystream(&mut dec); - dec.truncate(secret_key.len()) - } - _ => {} - } - Ok(dec) - } else { - Err(Error::KeyIsEncrypted) - } -} diff --git a/russh-keys/src/format/pkcs8.rs b/russh-keys/src/format/pkcs8.rs deleted file mode 100644 index a9cbb96d..00000000 --- a/russh-keys/src/format/pkcs8.rs +++ /dev/null @@ -1,439 +0,0 @@ -use std::borrow::Cow; - -use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit}; -use bit_vec::BitVec; -use block_padding::{NoPadding, Pkcs7}; -#[cfg(feature = "openssl")] -use openssl::pkey::Private; -#[cfg(feature = "openssl")] -use openssl::rsa::Rsa; -#[cfg(test)] -use rand_core::OsRng; -use std::convert::TryFrom; -use yasna::BERReaderSeq; -use {std, yasna}; - -use super::Encryption; -#[cfg(feature = "openssl")] -use crate::key::SignatureHash; -use crate::{key, Error}; - -const PBES2: &[u64] = &[1, 2, 840, 113549, 1, 5, 13]; -const PBKDF2: &[u64] = &[1, 2, 840, 113549, 1, 5, 12]; -const HMAC_SHA256: &[u64] = &[1, 2, 840, 113549, 2, 9]; -const AES256CBC: &[u64] = &[2, 16, 840, 1, 101, 3, 4, 1, 42]; -const ED25519: &[u64] = &[1, 3, 101, 112]; -#[cfg(feature = "openssl")] -const RSA: &[u64] = &[1, 2, 840, 113549, 1, 1, 1]; - -/// Decode a PKCS#8-encoded private key. -pub fn decode_pkcs8(ciphertext: &[u8], password: Option<&[u8]>) -> Result { - let secret = if let Some(pass) = password { - Cow::Owned(yasna::parse_der(ciphertext, |reader| { - reader.read_sequence(|reader| { - // Encryption parameters - let parameters = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == PBES2 { - asn1_read_pbes2(reader) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - // Ciphertext - let ciphertext = reader.next().read_bytes()?; - Ok(parameters.map(|p| p.decrypt(pass, &ciphertext))) - }) - })???) - } else { - Cow::Borrowed(ciphertext) - }; - yasna::parse_der(&secret, |reader| { - reader.read_sequence(|reader| { - let version = reader.next().read_u64()?; - if version == 0 { - Ok(read_key_v0(reader)) - } else if version == 1 { - Ok(read_key_v1(reader)) - } else { - Ok(Err(Error::CouldNotReadKey)) - } - }) - })? -} - -fn asn1_read_pbes2( - reader: &mut yasna::BERReaderSeq, -) -> Result, yasna::ASN1Error> { - reader.next().read_sequence(|reader| { - // PBES2 has two components. - // 1. Key generation algorithm - let keygen = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == PBKDF2 { - asn1_read_pbkdf2(reader) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - // 2. Encryption algorithm. - let algorithm = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == AES256CBC { - asn1_read_aes256cbc(reader) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - Ok(keygen.and_then(|keygen| algorithm.map(|algo| Algorithms::Pbes2(keygen, algo)))) - }) -} - -fn asn1_read_pbkdf2( - reader: &mut yasna::BERReaderSeq, -) -> Result, yasna::ASN1Error> { - reader.next().read_sequence(|reader| { - let salt = reader.next().read_bytes()?; - let rounds = reader.next().read_u64()?; - let digest = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - if oid.components().as_slice() == HMAC_SHA256 { - reader.next().read_null()?; - Ok(Ok(())) - } else { - Ok(Err(Error::UnknownAlgorithm(oid))) - } - })?; - Ok(digest.map(|()| KeyDerivation::Pbkdf2 { salt, rounds })) - }) -} - -fn asn1_read_aes256cbc( - reader: &mut yasna::BERReaderSeq, -) -> Result, yasna::ASN1Error> { - let iv = reader.next().read_bytes()?; - let mut i = [0; 16]; - i.clone_from_slice(&iv); - Ok(Ok(Encryption::Aes256Cbc(i))) -} - -fn write_key_v1(writer: &mut yasna::DERWriterSeq, secret: &ed25519_dalek::SigningKey) { - let public = ed25519_dalek::VerifyingKey::from(secret); - writer.next().write_u32(1); - // write OID - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(ED25519)); - }); - let seed = yasna::construct_der(|writer| { - writer.write_bytes( - [secret.to_bytes().as_slice(), public.as_bytes().as_slice()] - .concat() - .as_slice(), - ) - }); - writer.next().write_bytes(&seed); - writer - .next() - .write_tagged(yasna::Tag::context(1), |writer| { - writer.write_bitvec(&BitVec::from_bytes(public.as_bytes())) - }) -} - -fn read_key_v1(reader: &mut BERReaderSeq) -> Result { - let oid = reader - .next() - .read_sequence(|reader| reader.next().read_oid())?; - if oid.components().as_slice() == ED25519 { - use ed25519_dalek::SigningKey; - let secret = { - let s = yasna::parse_der(&reader.next().read_bytes()?, |reader| reader.read_bytes())?; - - s.get(..ed25519_dalek::SECRET_KEY_LENGTH) - .ok_or(Error::KeyIsCorrupt) - .and_then(|s| SigningKey::try_from(s).map_err(|_| Error::CouldNotReadKey))? - }; - // Consume the public key - reader - .next() - .read_tagged(yasna::Tag::context(1), |reader| reader.read_bitvec())?; - Ok(key::KeyPair::Ed25519(secret)) - } else { - Err(Error::CouldNotReadKey) - } -} - -#[cfg(feature = "openssl")] -fn write_key_v0(writer: &mut yasna::DERWriterSeq, key: &Rsa) { - writer.next().write_u32(0); - // write OID - writer.next().write_sequence(|writer| { - writer.next().write_oid(&ObjectIdentifier::from_slice(RSA)); - writer.next().write_null() - }); - let bytes = yasna::construct_der(|writer| { - #[allow(clippy::unwrap_used)] // key is known to be private - writer.write_sequence(|writer| { - writer.next().write_u32(0); - use num_bigint::BigUint; - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.n().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.e().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.d().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.p().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.q().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.dmp1().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.dmq1().unwrap().to_vec())); - writer - .next() - .write_biguint(&BigUint::from_bytes_be(&key.iqmp().unwrap().to_vec())); - }) - }); - writer.next().write_bytes(&bytes); -} - -#[cfg(feature = "openssl")] -fn read_key_v0(reader: &mut BERReaderSeq) -> Result { - let oid = reader.next().read_sequence(|reader| { - let oid = reader.next().read_oid()?; - reader.next().read_null()?; - Ok(oid) - })?; - if oid.components().as_slice() == RSA { - let seq = &reader.next().read_bytes()?; - let rsa: Result, Error> = yasna::parse_der(seq, |reader| { - reader.read_sequence(|reader| { - let version = reader.next().read_u32()?; - if version != 0 { - return Ok(Err(Error::CouldNotReadKey)); - } - use openssl::bn::BigNum; - let mut read_key = || -> Result, Error> { - Ok(Rsa::from_private_components( - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - BigNum::from_slice(&reader.next().read_biguint()?.to_bytes_be())?, - )?) - }; - Ok(read_key()) - }) - })?; - Ok(key::KeyPair::RSA { - key: rsa?, - hash: SignatureHash::SHA2_256, - }) - } else { - Err(Error::CouldNotReadKey) - } -} - -#[cfg(not(feature = "openssl"))] -fn read_key_v0(_: &mut BERReaderSeq) -> Result { - Err(Error::CouldNotReadKey) -} - -#[test] -fn test_read_write_pkcs8() { - let secret = ed25519_dalek::SigningKey::generate(&mut OsRng {}); - assert_eq!( - secret.verifying_key().as_bytes(), - ed25519_dalek::VerifyingKey::from(&secret).as_bytes() - ); - let key = key::KeyPair::Ed25519(secret); - let password = b"blabla"; - let ciphertext = encode_pkcs8_encrypted(password, 100, &key).unwrap(); - let key = decode_pkcs8(&ciphertext, Some(password)).unwrap(); - match key { - key::KeyPair::Ed25519 { .. } => println!("Ed25519"), - #[cfg(feature = "openssl")] - key::KeyPair::RSA { .. } => println!("RSA"), - } -} - -use aes::*; -use yasna::models::ObjectIdentifier; - -/// Encode a password-protected PKCS#8-encoded private key. -pub fn encode_pkcs8_encrypted( - pass: &[u8], - rounds: u32, - key: &key::KeyPair, -) -> Result, Error> { - use rand::RngCore; - let mut rng = rand::thread_rng(); - let mut salt = [0; 64]; - rng.fill_bytes(&mut salt); - let mut iv = [0; 16]; - rng.fill_bytes(&mut iv); - let mut dkey = [0; 32]; // AES256-CBC - pbkdf2::pbkdf2::>(pass, &salt, rounds, &mut dkey); - let mut plaintext = encode_pkcs8(key); - - let padding_len = 32 - (plaintext.len() % 32); - plaintext.extend(std::iter::repeat(padding_len as u8).take(padding_len)); - - #[allow(clippy::unwrap_used)] // parameters are static - let c = cbc::Encryptor::::new_from_slices(&dkey, &iv).unwrap(); - let n = plaintext.len(); - let encrypted = c.encrypt_padded_mut::(&mut plaintext, n)?; - - Ok(yasna::construct_der(|writer| { - writer.write_sequence(|writer| { - // Encryption parameters - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(PBES2)); - asn1_write_pbes2(writer.next(), rounds as u64, &salt, &iv) - }); - // Ciphertext - writer.next().write_bytes(encrypted) - }) - })) -} - -/// Encode a Decode a PKCS#8-encoded private key. -pub fn encode_pkcs8(key: &key::KeyPair) -> Vec { - yasna::construct_der(|writer| { - writer.write_sequence(|writer| match *key { - key::KeyPair::Ed25519(ref pair) => write_key_v1(writer, pair), - #[cfg(feature = "openssl")] - key::KeyPair::RSA { ref key, .. } => write_key_v0(writer, key), - }) - }) -} - -fn asn1_write_pbes2(writer: yasna::DERWriter, rounds: u64, salt: &[u8], iv: &[u8]) { - writer.write_sequence(|writer| { - // 1. Key generation algorithm - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(PBKDF2)); - asn1_write_pbkdf2(writer.next(), rounds, salt) - }); - // 2. Encryption algorithm. - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(AES256CBC)); - writer.next().write_bytes(iv) - }); - }) -} - -fn asn1_write_pbkdf2(writer: yasna::DERWriter, rounds: u64, salt: &[u8]) { - writer.write_sequence(|writer| { - writer.next().write_bytes(salt); - writer.next().write_u64(rounds); - writer.next().write_sequence(|writer| { - writer - .next() - .write_oid(&ObjectIdentifier::from_slice(HMAC_SHA256)); - writer.next().write_null() - }) - }) -} - -enum Algorithms { - Pbes2(KeyDerivation, Encryption), -} - -impl Algorithms { - fn decrypt(&self, password: &[u8], cipher: &[u8]) -> Result, Error> { - match *self { - Algorithms::Pbes2(ref der, ref enc) => { - let mut key = enc.key(); - der.derive(password, &mut key)?; - let out = enc.decrypt(&key, cipher)?; - Ok(out) - } - } - } -} - -impl KeyDerivation { - fn derive(&self, password: &[u8], key: &mut [u8]) -> Result<(), Error> { - match *self { - KeyDerivation::Pbkdf2 { ref salt, rounds } => { - pbkdf2::pbkdf2::>(password, salt, rounds as u32, key) - // pbkdf2_hmac(password, salt, rounds as usize, digest, key)? - } - } - Ok(()) - } -} - -#[derive(Debug)] -enum Key { - K128([u8; 16]), - K256([u8; 32]), -} - -impl std::ops::Deref for Key { - type Target = [u8]; - fn deref(&self) -> &[u8] { - match *self { - Key::K128(ref k) => k, - Key::K256(ref k) => k, - } - } -} - -impl std::ops::DerefMut for Key { - fn deref_mut(&mut self) -> &mut [u8] { - match *self { - Key::K128(ref mut k) => k, - Key::K256(ref mut k) => k, - } - } -} - -impl Encryption { - fn key(&self) -> Key { - match *self { - Encryption::Aes128Cbc(_) => Key::K128([0; 16]), - Encryption::Aes256Cbc(_) => Key::K256([0; 32]), - } - } - - fn decrypt(&self, key: &[u8], ciphertext: &[u8]) -> Result, Error> { - match *self { - Encryption::Aes128Cbc(ref iv) => { - #[allow(clippy::unwrap_used)] // parameters are static - let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let mut dec = ciphertext.to_vec(); - Ok(c.decrypt_padded_mut::(&mut dec)?.into()) - } - Encryption::Aes256Cbc(ref iv) => { - #[allow(clippy::unwrap_used)] // parameters are static - let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); - let mut dec = ciphertext.to_vec(); - Ok(c.decrypt_padded_mut::(&mut dec)?.into()) - } - } - } -} - -enum KeyDerivation { - Pbkdf2 { salt: Vec, rounds: u64 }, -} diff --git a/russh-keys/src/key.rs b/russh-keys/src/key.rs deleted file mode 100644 index 98ebda9f..00000000 --- a/russh-keys/src/key.rs +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2016 Pierre-Étienne Meunier -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -use ed25519_dalek::{Signer, Verifier}; -#[cfg(feature = "openssl")] -use openssl::pkey::{Private, Public}; -use rand_core::OsRng; -use russh_cryptovec::CryptoVec; -use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; - -use crate::encoding::{Encoding, Reader}; -pub use crate::signature::*; -use crate::Error; - -#[derive(Debug, PartialEq, Eq, Copy, Clone)] -/// Name of a public key algorithm. -pub struct Name(pub &'static str); - -impl AsRef for Name { - fn as_ref(&self) -> &str { - self.0 - } -} - -/// The name of the Ed25519 algorithm for SSH. -pub const ED25519: Name = Name("ssh-ed25519"); -/// The name of the ssh-sha2-512 algorithm for SSH. -pub const RSA_SHA2_512: Name = Name("rsa-sha2-512"); -/// The name of the ssh-sha2-256 algorithm for SSH. -pub const RSA_SHA2_256: Name = Name("rsa-sha2-256"); - -pub const NONE: Name = Name("none"); - -pub const SSH_RSA: Name = Name("ssh-rsa"); - -impl Name { - /// Base name of the private key file for a key name. - pub fn identity_file(&self) -> &'static str { - match *self { - ED25519 => "id_ed25519", - RSA_SHA2_512 => "id_rsa", - RSA_SHA2_256 => "id_rsa", - _ => unreachable!(), - } - } -} - -#[doc(hidden)] -pub trait Verify { - fn verify_client_auth(&self, buffer: &[u8], sig: &[u8]) -> bool; - fn verify_server_auth(&self, buffer: &[u8], sig: &[u8]) -> bool; -} - -/// The hash function used for signing with RSA keys. -#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash, Serialize, Deserialize)] -#[allow(non_camel_case_types)] -pub enum SignatureHash { - /// SHA2, 256 bits. - SHA2_256, - /// SHA2, 512 bits. - SHA2_512, - /// SHA1 - SHA1, -} - -impl SignatureHash { - pub fn name(&self) -> Name { - match *self { - SignatureHash::SHA2_256 => RSA_SHA2_256, - SignatureHash::SHA2_512 => RSA_SHA2_512, - SignatureHash::SHA1 => SSH_RSA, - } - } - - #[cfg(feature = "openssl")] - fn message_digest(&self) -> openssl::hash::MessageDigest { - use openssl::hash::MessageDigest; - match *self { - SignatureHash::SHA2_256 => MessageDigest::sha256(), - SignatureHash::SHA2_512 => MessageDigest::sha512(), - SignatureHash::SHA1 => MessageDigest::sha1(), - } - } - - pub fn from_rsa_hostkey_algo(algo: &[u8]) -> Option { - if algo == b"rsa-sha2-256" { - Some(Self::SHA2_256) - } else if algo == b"rsa-sha2-512" { - Some(Self::SHA2_512) - } else { - Some(Self::SHA1) - } - } -} - -/// Public key -#[derive(Eq, Debug, Clone)] -pub enum PublicKey { - #[doc(hidden)] - Ed25519(ed25519_dalek::VerifyingKey), - #[doc(hidden)] - #[cfg(feature = "openssl")] - RSA { - key: OpenSSLPKey, - hash: SignatureHash, - }, -} - -impl PartialEq for PublicKey { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - #[cfg(feature = "openssl")] - (Self::RSA { key: a, .. }, Self::RSA { key: b, .. }) => a == b, - (Self::Ed25519(a), Self::Ed25519(b)) => a == b, - #[cfg(feature = "openssl")] - _ => false, - } - } -} - -/// A public key from OpenSSL. -#[cfg(feature = "openssl")] -#[derive(Clone)] -pub struct OpenSSLPKey(pub openssl::pkey::PKey); - -#[cfg(feature = "openssl")] -use std::cmp::{Eq, PartialEq}; - -#[cfg(feature = "openssl")] -impl PartialEq for OpenSSLPKey { - fn eq(&self, b: &OpenSSLPKey) -> bool { - self.0.public_eq(&b.0) - } -} -#[cfg(feature = "openssl")] -impl Eq for OpenSSLPKey {} -#[cfg(feature = "openssl")] -impl std::fmt::Debug for OpenSSLPKey { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "OpenSSLPKey {{ (hidden) }}") - } -} - -impl PublicKey { - /// Parse a public key in SSH format. - pub fn parse(algo: &[u8], pubkey: &[u8]) -> Result { - match algo { - b"ssh-ed25519" => { - let mut p = pubkey.reader(0); - let key_algo = p.read_string()?; - let key_bytes = p.read_string()?; - if key_algo != b"ssh-ed25519" { - return Err(Error::CouldNotReadKey); - } - let Ok(key_bytes) = <&[u8; ed25519_dalek::PUBLIC_KEY_LENGTH]>::try_from(key_bytes) else { - return Err(Error::CouldNotReadKey); - }; - ed25519_dalek::VerifyingKey::from_bytes(key_bytes) - .map(PublicKey::Ed25519) - .map_err(Error::from) - } - b"ssh-rsa" | b"rsa-sha2-256" | b"rsa-sha2-512" if cfg!(feature = "openssl") => { - #[cfg(feature = "openssl")] - { - use log::debug; - let mut p = pubkey.reader(0); - let key_algo = p.read_string()?; - debug!("{:?}", std::str::from_utf8(key_algo)); - if key_algo != b"ssh-rsa" - && key_algo != b"rsa-sha2-256" - && key_algo != b"rsa-sha2-512" - { - return Err(Error::CouldNotReadKey); - } - let key_e = p.read_string()?; - let key_n = p.read_string()?; - use openssl::bn::BigNum; - use openssl::pkey::PKey; - use openssl::rsa::Rsa; - Ok(PublicKey::RSA { - key: OpenSSLPKey(PKey::from_rsa(Rsa::from_public_components( - BigNum::from_slice(key_n)?, - BigNum::from_slice(key_e)?, - )?)?), - hash: SignatureHash::from_rsa_hostkey_algo(algo) - .unwrap_or(SignatureHash::SHA1), - }) - } - #[cfg(not(feature = "openssl"))] - { - unreachable!() - } - } - _ => Err(Error::CouldNotReadKey), - } - } - - /// Algorithm name for that key. - pub fn name(&self) -> &'static str { - match *self { - PublicKey::Ed25519(_) => ED25519.0, - #[cfg(feature = "openssl")] - PublicKey::RSA { ref hash, .. } => hash.name().0, - } - } - - /// Verify a signature. - pub fn verify_detached(&self, buffer: &[u8], sig: &[u8]) -> bool { - match self { - PublicKey::Ed25519(ref public) => { - let Ok(sig) = ed25519_dalek::ed25519::SignatureBytes::try_from(sig) else { - return false; - }; - let sig = ed25519_dalek::Signature::from_bytes(&sig); - public.verify(buffer, &sig).is_ok() - } - - #[cfg(feature = "openssl")] - PublicKey::RSA { ref key, ref hash } => { - use openssl::sign::*; - let verify = || { - let mut verifier = Verifier::new(hash.message_digest(), &key.0)?; - verifier.update(buffer)?; - verifier.verify(sig) - }; - verify().unwrap_or(false) - } - } - } - - /// Compute the key fingerprint, hashed with sha2-256. - pub fn fingerprint(&self) -> String { - use super::PublicKeyBase64; - let key = self.public_key_bytes(); - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(&key[..]); - data_encoding::BASE64_NOPAD.encode(&hasher.finalize()) - } - - #[cfg(feature = "openssl")] - pub fn set_algorithm(&mut self, algorithm: &[u8]) { - if let PublicKey::RSA { ref mut hash, .. } = self { - if algorithm == b"rsa-sha2-512" { - *hash = SignatureHash::SHA2_512 - } else if algorithm == b"rsa-sha2-256" { - *hash = SignatureHash::SHA2_256 - } else if algorithm == b"ssh-rsa" { - *hash = SignatureHash::SHA1 - } - } - } - - #[cfg(not(feature = "openssl"))] - pub fn set_algorithm(&mut self, _: &[u8]) {} -} - -impl Verify for PublicKey { - fn verify_client_auth(&self, buffer: &[u8], sig: &[u8]) -> bool { - self.verify_detached(buffer, sig) - } - fn verify_server_auth(&self, buffer: &[u8], sig: &[u8]) -> bool { - self.verify_detached(buffer, sig) - } -} - -/// Public key exchange algorithms. -#[allow(clippy::large_enum_variant)] -pub enum KeyPair { - Ed25519(ed25519_dalek::SigningKey), - #[cfg(feature = "openssl")] - RSA { - key: openssl::rsa::Rsa, - hash: SignatureHash, - }, -} - -impl Clone for KeyPair { - fn clone(&self) -> Self { - match self { - #[allow(clippy::expect_used)] - Self::Ed25519(kp) => { - Self::Ed25519(ed25519_dalek::SigningKey::from_bytes(&kp.to_bytes())) - } - #[cfg(feature = "openssl")] - Self::RSA { key, hash } => Self::RSA { - key: key.clone(), - hash: *hash, - }, - } - } -} - -impl std::fmt::Debug for KeyPair { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match *self { - KeyPair::Ed25519(ref key) => write!( - f, - "Ed25519 {{ public: {:?}, secret: (hidden) }}", - key.verifying_key().as_bytes() - ), - #[cfg(feature = "openssl")] - KeyPair::RSA { .. } => write!(f, "RSA {{ (hidden) }}"), - } - } -} - -impl<'b> crate::encoding::Bytes for &'b KeyPair { - fn bytes(&self) -> &[u8] { - self.name().as_bytes() - } -} - -impl KeyPair { - /// Copy the public key of this algorithm. - pub fn clone_public_key(&self) -> Result { - Ok(match self { - KeyPair::Ed25519(ref key) => PublicKey::Ed25519(key.verifying_key()), - #[cfg(feature = "openssl")] - KeyPair::RSA { ref key, ref hash } => { - use openssl::pkey::PKey; - use openssl::rsa::Rsa; - let key = Rsa::from_public_components(key.n().to_owned()?, key.e().to_owned()?)?; - PublicKey::RSA { - key: OpenSSLPKey(PKey::from_rsa(key)?), - hash: *hash, - } - } - }) - } - - /// Name of this key algorithm. - pub fn name(&self) -> &'static str { - match *self { - KeyPair::Ed25519(_) => ED25519.0, - #[cfg(feature = "openssl")] - KeyPair::RSA { ref hash, .. } => hash.name().0, - } - } - - /// Generate a key pair. - pub fn generate_ed25519() -> Option { - let keypair = ed25519_dalek::SigningKey::generate(&mut OsRng {}); - assert_eq!( - keypair.verifying_key().as_bytes(), - ed25519_dalek::VerifyingKey::from(&keypair).as_bytes() - ); - Some(KeyPair::Ed25519(keypair)) - } - - #[cfg(feature = "openssl")] - pub fn generate_rsa(bits: usize, hash: SignatureHash) -> Option { - let key = openssl::rsa::Rsa::generate(bits as u32).ok()?; - Some(KeyPair::RSA { key, hash }) - } - - /// Sign a slice using this algorithm. - pub fn sign_detached(&self, to_sign: &[u8]) -> Result { - match self { - #[allow(clippy::unwrap_used)] - KeyPair::Ed25519(ref secret) => Ok(Signature::Ed25519(SignatureBytes( - secret.sign(to_sign).to_bytes(), - ))), - #[cfg(feature = "openssl")] - KeyPair::RSA { ref key, ref hash } => Ok(Signature::RSA { - bytes: rsa_signature(hash, key, to_sign)?, - hash: *hash, - }), - } - } - - #[doc(hidden)] - /// This is used by the server to sign the initial DH kex - /// message. Note: we are not signing the same kind of thing as in - /// the function below, `add_self_signature`. - pub fn add_signature>( - &self, - buffer: &mut CryptoVec, - to_sign: H, - ) -> Result<(), Error> { - match self { - KeyPair::Ed25519(ref secret) => { - let signature = secret.sign(to_sign.as_ref()); - - buffer.push_u32_be((ED25519.0.len() + signature.to_bytes().len() + 8) as u32); - buffer.extend_ssh_string(ED25519.0.as_bytes()); - buffer.extend_ssh_string(signature.to_bytes().as_slice()); - } - #[cfg(feature = "openssl")] - KeyPair::RSA { ref key, ref hash } => { - // https://tools.ietf.org/html/draft-rsa-dsa-sha2-256-02#section-2.2 - let signature = rsa_signature(hash, key, to_sign.as_ref())?; - let name = hash.name(); - buffer.push_u32_be((name.0.len() + signature.len() + 8) as u32); - buffer.extend_ssh_string(name.0.as_bytes()); - buffer.extend_ssh_string(&signature); - } - } - Ok(()) - } - - #[doc(hidden)] - /// This is used by the client for authentication. Note: we are - /// not signing the same kind of thing as in the above function, - /// `add_signature`. - pub fn add_self_signature(&self, buffer: &mut CryptoVec) -> Result<(), Error> { - match self { - KeyPair::Ed25519(ref secret) => { - let signature = secret.sign(buffer); - buffer.push_u32_be((ED25519.0.len() + signature.to_bytes().len() + 8) as u32); - buffer.extend_ssh_string(ED25519.0.as_bytes()); - buffer.extend_ssh_string(signature.to_bytes().as_slice()); - } - #[cfg(feature = "openssl")] - KeyPair::RSA { ref key, ref hash } => { - // https://tools.ietf.org/html/draft-rsa-dsa-sha2-256-02#section-2.2 - let signature = rsa_signature(hash, key, buffer)?; - let name = hash.name(); - buffer.push_u32_be((name.0.len() + signature.len() + 8) as u32); - buffer.extend_ssh_string(name.0.as_bytes()); - buffer.extend_ssh_string(&signature); - } - } - Ok(()) - } - - /// Create a copy of an RSA key with a specified hash algorithm. - #[cfg(feature = "openssl")] - pub fn with_signature_hash(&self, hash: SignatureHash) -> Option { - match self { - KeyPair::Ed25519(_) => None, - #[cfg(feature = "openssl")] - KeyPair::RSA { key, .. } => Some(KeyPair::RSA { - key: key.clone(), - hash, - }), - } - } -} - -#[cfg(feature = "openssl")] -fn rsa_signature( - hash: &SignatureHash, - key: &openssl::rsa::Rsa, - b: &[u8], -) -> Result, Error> { - use openssl::pkey::*; - use openssl::rsa::*; - use openssl::sign::Signer; - let pkey = PKey::from_rsa(Rsa::from_private_components( - key.n().to_owned()?, - key.e().to_owned()?, - key.d().to_owned()?, - key.p().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.q().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.dmp1().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.dmq1().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - key.iqmp().ok_or(Error::KeyIsCorrupt)?.to_owned()?, - )?)?; - let mut signer = Signer::new(hash.message_digest(), &pkey)?; - signer.update(b)?; - Ok(signer.sign_to_vec()?) -} - -/// Parse a public key from a byte slice. -pub fn parse_public_key( - p: &[u8], - #[cfg(feature = "openssl")] prefer_hash: Option, -) -> Result { - let mut pos = p.reader(0); - let t = pos.read_string()?; - if t == b"ssh-ed25519" { - if let Ok(pubkey) = pos.read_string() { - let Ok(pubkey) = <&[u8; ed25519_dalek::PUBLIC_KEY_LENGTH]>::try_from(pubkey) else { - return Err(Error::CouldNotReadKey); - }; - let p = ed25519_dalek::VerifyingKey::from_bytes(pubkey).map_err(Error::from)?; - return Ok(PublicKey::Ed25519(p)); - } - } - if t == b"ssh-rsa" { - #[cfg(feature = "openssl")] - { - let e = pos.read_string()?; - let n = pos.read_string()?; - use openssl::bn::*; - use openssl::pkey::*; - use openssl::rsa::*; - return Ok(PublicKey::RSA { - key: OpenSSLPKey(PKey::from_rsa(Rsa::from_public_components( - BigNum::from_slice(n)?, - BigNum::from_slice(e)?, - )?)?), - hash: prefer_hash.unwrap_or(SignatureHash::SHA2_256), - }); - } - } - Err(Error::CouldNotReadKey) -} diff --git a/russh-keys/src/signature.rs b/russh-keys/src/signature.rs deleted file mode 100644 index 712139c2..00000000 --- a/russh-keys/src/signature.rs +++ /dev/null @@ -1,157 +0,0 @@ -use std::fmt; - -use byteorder::{BigEndian, WriteBytesExt}; -use serde; -use serde::de::{SeqAccess, Visitor}; -use serde::ser::SerializeTuple; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -use crate::key::SignatureHash; -use crate::Error; - -pub struct SignatureBytes(pub [u8; 64]); - -/// The type of a signature, depending on the algorithm used. -#[derive(Serialize, Deserialize, Clone)] -pub enum Signature { - /// An Ed25519 signature - Ed25519(SignatureBytes), - /// An RSA signature - RSA { hash: SignatureHash, bytes: Vec }, -} - -impl Signature { - pub fn to_base64(&self) -> String { - use crate::encoding::Encoding; - let mut bytes_ = Vec::new(); - match self { - Signature::Ed25519(ref bytes) => { - let t = b"ssh-ed25519"; - #[allow(clippy::unwrap_used)] // Vec<>.write_all can't fail - bytes_ - .write_u32::((t.len() + bytes.0.len() + 8) as u32) - .unwrap(); - bytes_.extend_ssh_string(t); - bytes_.extend_ssh_string(&bytes.0[..]); - } - Signature::RSA { - ref hash, - ref bytes, - } => { - let t = match hash { - SignatureHash::SHA2_256 => &b"rsa-sha2-256"[..], - SignatureHash::SHA2_512 => &b"rsa-sha2-512"[..], - SignatureHash::SHA1 => &b"ssh-rsa"[..], - }; - #[allow(clippy::unwrap_used)] // Vec<>.write_all can't fail - bytes_ - .write_u32::((t.len() + bytes.len() + 8) as u32) - .unwrap(); - bytes_.extend_ssh_string(t); - bytes_.extend_ssh_string(&bytes[..]); - } - } - data_encoding::BASE64_NOPAD.encode(&bytes_[..]) - } - - pub fn from_base64(s: &[u8]) -> Result { - let bytes_ = data_encoding::BASE64_NOPAD.decode(s)?; - use crate::encoding::Reader; - let mut r = bytes_.reader(0); - let sig = r.read_string()?; - let mut r = sig.reader(0); - let typ = r.read_string()?; - let bytes = r.read_string()?; - match typ { - b"ssh-ed25519" => { - let mut bytes_ = [0; 64]; - bytes_.clone_from_slice(bytes); - Ok(Signature::Ed25519(SignatureBytes(bytes_))) - } - b"rsa-sha2-256" => Ok(Signature::RSA { - hash: SignatureHash::SHA2_256, - bytes: bytes.to_vec(), - }), - b"rsa-sha2-512" => Ok(Signature::RSA { - hash: SignatureHash::SHA2_512, - bytes: bytes.to_vec(), - }), - b"ssh-rsa" => Ok(Signature::RSA { - hash: SignatureHash::SHA1, - bytes: bytes.to_vec(), - }), - _ => Err(Error::UnknownSignatureType { - sig_type: std::str::from_utf8(typ).unwrap_or("").to_string(), - }), - } - } -} - -impl AsRef<[u8]> for Signature { - fn as_ref(&self) -> &[u8] { - match *self { - Signature::Ed25519(ref signature) => &signature.0, - Signature::RSA { ref bytes, .. } => &bytes[..], - } - } -} - -impl AsRef<[u8]> for SignatureBytes { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl<'de> Deserialize<'de> for SignatureBytes { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct Vis; - impl<'de> Visitor<'de> for Vis { - type Value = SignatureBytes; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("64 bytes") - } - fn visit_seq>(self, mut seq: A) -> Result { - let mut result = [0; 64]; - for x in result.iter_mut() { - if let Some(y) = seq.next_element()? { - *x = y - } else { - return Err(serde::de::Error::invalid_length(64, &self)); - } - } - Ok(SignatureBytes(result)) - } - } - deserializer.deserialize_tuple(64, Vis) - } -} - -impl Serialize for SignatureBytes { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut tup = serializer.serialize_tuple(64)?; - for byte in self.0.iter() { - tup.serialize_element(byte)?; - } - tup.end() - } -} - -impl fmt::Debug for SignatureBytes { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{:?}", &self.0[..]) - } -} - -impl Clone for SignatureBytes { - fn clone(&self) -> Self { - let mut result = SignatureBytes([0; 64]); - result.0.clone_from_slice(&self.0); - result - } -} diff --git a/russh-util/Cargo.toml b/russh-util/Cargo.toml new file mode 100644 index 00000000..2c36a6c2 --- /dev/null +++ b/russh-util/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "russh-util" +version = "0.52.0" +edition = "2021" +rust-version = "1.75" +description = "Runtime abstraction utilities for russh." +documentation = "https://docs.rs/russh-util" +homepage = "https://github.com/warp-tech/russh" +license = "Apache-2.0" +repository = "https://github.com/warp-tech/russh" + +[dependencies] +tokio = { workspace = true, features = ["sync", "macros"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +chrono = "0.4.38" +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4.43" + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { workspace = true, features = ["io-util", "rt-multi-thread", "rt"] } diff --git a/russh-util/src/lib.rs b/russh-util/src/lib.rs new file mode 100644 index 00000000..ba4302eb --- /dev/null +++ b/russh-util/src/lib.rs @@ -0,0 +1,2 @@ +pub mod runtime; +pub mod time; diff --git a/russh-util/src/runtime.rs b/russh-util/src/runtime.rs new file mode 100644 index 00000000..ad6d280a --- /dev/null +++ b/russh-util/src/runtime.rs @@ -0,0 +1,63 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug)] +pub struct JoinError; + +impl std::fmt::Display for JoinError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JoinError") + } +} + +impl std::error::Error for JoinError {} + +pub struct JoinHandle +where + T: Send, +{ + handle: tokio::sync::oneshot::Receiver, +} + +#[cfg(target_arch = "wasm32")] +macro_rules! spawn_impl { + ($fn:expr) => { + wasm_bindgen_futures::spawn_local($fn) + }; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! spawn_impl { + ($fn:expr) => { + tokio::spawn($fn) + }; +} + +pub fn spawn(future: F) -> JoinHandle +where + F: Future + 'static + Send, + T: Send + 'static, +{ + let (sender, receiver) = tokio::sync::oneshot::channel(); + spawn_impl!(async { + let result = future.await; + let _ = sender.send(result); + }); + JoinHandle { handle: receiver } +} + +impl Future for JoinHandle +where + T: Send, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.handle).poll(cx) { + Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), + Poll::Ready(Err(_)) => Poll::Ready(Err(JoinError)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/russh-util/src/time.rs b/russh-util/src/time.rs new file mode 100644 index 00000000..a5e1adc2 --- /dev/null +++ b/russh-util/src/time.rs @@ -0,0 +1,27 @@ +#[cfg(not(target_arch = "wasm32"))] +pub use std::time::Instant; + +#[cfg(target_arch = "wasm32")] +pub use wasm::Instant; + +#[cfg(target_arch = "wasm32")] +mod wasm { + #[derive(Debug, Clone, Copy)] + pub struct Instant { + inner: chrono::DateTime, + } + + impl Instant { + pub fn now() -> Self { + Instant { + inner: chrono::Utc::now(), + } + } + + pub fn duration_since(&self, earlier: Instant) -> std::time::Duration { + (self.inner - earlier.inner) + .to_std() + .expect("Duration is negative") + } + } +} diff --git a/russh/Cargo.toml b/russh/Cargo.toml index 8f792f79..e5b28c13 100644 --- a/russh/Cargo.toml +++ b/russh/Cargo.toml @@ -2,71 +2,135 @@ authors = ["Pierre-Étienne Meunier "] description = "A client and server SSH library." documentation = "https://docs.rs/russh" -edition = "2018" -homepage = "https://pijul.org/russh" +edition = "2021" +homepage = "https://github.com/warp-tech/russh" keywords = ["ssh"] license = "Apache-2.0" name = "russh" readme = "../README.md" repository = "https://github.com/warp-tech/russh" -version = "0.38.0-beta.1" -rust-version = "1.60" +version = "0.53.0-beta.1" +rust-version = "1.75" [features] -default = ["flate2"] -openssl = ["russh-keys/openssl", "dep:openssl"] -vendored-openssl = ["openssl/vendored", "russh-keys/vendored-openssl"] +default = ["flate2", "aws-lc-rs"] +aws-lc-rs = ["dep:aws-lc-rs"] +async-trait = ["dep:async-trait"] +legacy-ed25519-pkcs8-parser = ["yasna"] +# Danger: 3DES cipher is insecure. +des = ["dep:des"] +# Danger: DSA algorithm is insecure. +dsa = ["ssh-key/dsa"] +ring = ["dep:ring"] +_bench = ["dep:criterion"] [dependencies] -aes = "0.8" -aes-gcm = "0.10" -async-trait = "0.1" +aes.workspace = true +async-trait = { workspace = true, optional = true } +aws-lc-rs = { version = "1.13.1", optional = true } bitflags = "2.0" -byteorder = "1.3" -chacha20 = "0.9" -curve25519-dalek = "4.0" -poly1305 = "0.8" +block-padding = { version = "0.3", features = ["std"] } +byteorder.workspace = true +bytes.workspace = true +cbc = { version = "0.1" } ctr = "0.9" -digest = "0.10" -flate2 = { version = "1.0", optional = true } -futures = "0.3" +curve25519-dalek = "4.1.3" +data-encoding = "2.3" +delegate.workspace = true +digest.workspace = true +der = "0.7" +des = { version = "0.8.1", optional = true } +ecdsa = "0.16" +ed25519-dalek = { version = "2.0", features = ["rand_core", "pkcs8"] } +elliptic-curve = { version = "0.13", features = ["ecdh"] } +enum_dispatch = "0.3.13" +flate2 = { version = "1.0.15", optional = true } +futures.workspace = true generic-array = "0.14" -hmac = "0.12" -log = "0.4" -once_cell = "1.13" -openssl = { version = "0.10", optional = true } -rand = "0.8" -russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } -russh-keys = { version = "0.37.1", path = "../russh-keys" } -sha1 = "0.10" -sha2 = "0.10" +getrandom = { version = "0.2.15", features = ["js"] } hex-literal = "0.4" -num-bigint = { version = "0.4", features = ["rand"] } +hmac.workspace = true +inout = { version = "0.1", features = ["std"] } +log.workspace = true +md5 = "0.7" +num-bigint = { version = "0.4.2", features = ["rand"] } +# num-integer = "0.1" +once_cell = "1.13" +p256 = { version = "0.13", features = ["ecdh"] } +p384 = { version = "0.13", features = ["ecdh"] } +p521 = { version = "0.13", features = ["ecdh"] } +pbkdf2 = "0.12" +pkcs1 = "0.7" +pkcs5 = "0.7" +pkcs8 = { version = "0.10", features = ["pkcs5", "encryption"] } +rand_core = { version = "0.6.4", features = ["getrandom", "std"] } +rand.workspace = true +ring = { version = "0.17.14", optional = true } +rsa.workspace = true +russh-cryptovec = { version = "0.52.0", path = "../cryptovec", features = [ + "ssh-encoding", +] } +russh-util = { version = "0.52.0", path = "../russh-util" } +sec1 = { version = "0.7", features = ["pkcs8", "der"] } +sha1.workspace = true +sha2.workspace = true +signature.workspace = true +spki = "0.7" +ssh-encoding.workspace = true +ssh-key.workspace = true subtle = "2.4" -thiserror = "1.0" -tokio = { version = "1.17.0", features = [ - "io-util", - "rt-multi-thread", - "time", - "net", - "sync", - "macros", - "process", +thiserror.workspace = true +tokio = { workspace = true, features = ["io-util", "sync", "time"] } +typenum = "1.17" +yasna = { version = "0.5.0", features = [ + "bit-vec", + "num-bigint", +], optional = true } +zeroize = "1.7" +base64ct = "~1.6" # can be removed in 2024 edition +criterion = { version = "0.3", optional = true, features = ["html_reports"] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { workspace = true, features = [ + "io-util", + "rt-multi-thread", + "time", + "net", ] } -tokio-util = "0.7" +home.workspace = true + +[target.'cfg(windows)'.dependencies] +pageant = { version = "0.0.3", path = "../pageant" } [dev-dependencies] -anyhow = "1.0" -env_logger = "0.10" -tokio = { version = "1.17.0", features = [ - "io-util", - "rt-multi-thread", - "time", - "net", - "sync", - "macros", +anyhow = "1.0.4" +env_logger.workspace = true +clap = { version = "3.2.3", features = ["derive"] } +tokio = { workspace = true, features = [ + "io-std", + "io-util", + "rt-multi-thread", + "time", + "net", + "sync", + "macros", + "process", ] } -russh-sftp = "1.1" +rand = "0.8.5" +shell-escape = "0.1" +tokio-fd = "0.3" +termion = "2" +ratatui = "0.29.0" +tempfile = "3.14.0" + +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +russh-sftp = "2.1.0" +tokio.workspace = true +tokio-stream.workspace = true [package.metadata.docs.rs] -features = ["openssl"] +all-features = true + +[[bench]] +name = "ciphers" +harness = false diff --git a/russh/benches/ciphers.rs b/russh/benches/ciphers.rs new file mode 100755 index 00000000..02a9b0a2 --- /dev/null +++ b/russh/benches/ciphers.rs @@ -0,0 +1,4 @@ +use criterion::{criterion_group, criterion_main}; +use russh::cipher::benchmark::bench; +criterion_group!(benches, bench); +criterion_main!(benches); diff --git a/russh/examples/client.rs b/russh/examples/client.rs deleted file mode 100644 index 6948835f..00000000 --- a/russh/examples/client.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; - -use anyhow::Context; -use async_trait::async_trait; -use russh::*; -use russh_keys::*; - -struct Client {} - -#[async_trait] -impl client::Handler for Client { - type Error = russh::Error; - - async fn check_server_key( - self, - server_public_key: &key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - println!("check_server_key: {:?}", server_public_key); - Ok((self, true)) - } -} - -#[tokio::main] -async fn main() { - env_logger::init(); - let config = russh::client::Config::default(); - let config = Arc::new(config); - let sh = Client {}; - - let mut agent = russh_keys::agent::client::AgentClient::connect_env() - .await - .unwrap(); - let mut identities = agent.request_identities().await.unwrap(); - let mut session = russh::client::connect(config, ("127.0.0.1", 2200), sh) - .await - .unwrap(); - let (_, auth_res) = session - .authenticate_future("pe", identities.pop().unwrap(), agent) - .await; - let auth_res = auth_res.unwrap(); - println!("=== auth: {}", auth_res); - let mut channel = session - .channel_open_direct_tcpip("localhost", 8000, "localhost", 3333) - .await - .unwrap(); - // let mut channel = session.channel_open_session().await.unwrap(); - println!("=== after open channel"); - let data = b"GET /les_affames.mkv HTTP/1.1\nUser-Agent: curl/7.68.0\nAccept: */*\nConnection: close\n\n"; - channel.data(&data[..]).await.unwrap(); - let mut f = std::fs::File::create("les_affames.mkv").unwrap(); - while let Some(msg) = channel.wait().await { - use std::io::Write; - match msg { - russh::ChannelMsg::Data { ref data } => { - f.write_all(data).unwrap(); - } - russh::ChannelMsg::Eof => { - f.flush().unwrap(); - break; - } - _ => {} - } - } - session - .disconnect(Disconnect::ByApplication, "", "English") - .await - .unwrap(); - let res = session.await.context("session await"); - println!("{:#?}", res); -} diff --git a/russh/examples/client_exec_interactive.rs b/russh/examples/client_exec_interactive.rs new file mode 100644 index 00000000..2c088bc8 --- /dev/null +++ b/russh/examples/client_exec_interactive.rs @@ -0,0 +1,231 @@ +/// +/// Run this example with: +/// cargo run --example client_exec_interactive -- -k +/// +use std::convert::TryFrom; +use std::env; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use clap::Parser; +use log::info; +use russh::keys::*; +use russh::*; +use termion::raw::IntoRawMode; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::ToSocketAddrs; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .init(); + + // CLI options are defined later in this file + let cli = Cli::parse(); + + info!("Connecting to {}:{}", cli.host, cli.port); + info!("Key path: {:?}", cli.private_key); + info!("OpenSSH Certificate path: {:?}", cli.openssh_certificate); + + // Session is a wrapper around a russh client, defined down below + let mut ssh = Session::connect( + cli.private_key, + cli.username.unwrap_or("root".to_string()), + cli.openssh_certificate, + (cli.host, cli.port), + ) + .await?; + info!("Connected"); + + let code = { + // We're using `termion` to put the terminal into raw mode, so that we can + // display the output of interactive applications correctly + let _raw_term = std::io::stdout().into_raw_mode()?; + ssh.call( + &cli.command + .into_iter() + .map(|x| shell_escape::escape(x.into())) // arguments are escaped manually since the SSH protocol doesn't support quoting + .collect::>() + .join(" "), + ) + .await? + }; + + println!("Exitcode: {:?}", code); + ssh.close().await?; + Ok(()) +} + +struct Client {} + +// More SSH event handlers +// can be defined in this trait +// In this example, we're only using Channel, so these aren't needed. +impl client::Handler for Client { + type Error = russh::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &ssh_key::PublicKey, + ) -> Result { + Ok(true) + } +} + +/// This struct is a convenience wrapper +/// around a russh client +/// that handles the input/output event loop +pub struct Session { + session: client::Handle, +} + +impl Session { + async fn connect, A: ToSocketAddrs>( + key_path: P, + user: impl Into, + openssh_cert_path: Option

, + addrs: A, + ) -> Result { + let key_pair = load_secret_key(key_path, None)?; + + // load ssh certificate + let mut openssh_cert = None; + if openssh_cert_path.is_some() { + openssh_cert = Some(load_openssh_certificate(openssh_cert_path.unwrap())?); + } + + let config = client::Config { + inactivity_timeout: Some(Duration::from_secs(5)), + ..<_>::default() + }; + + let config = Arc::new(config); + let sh = Client {}; + + let mut session = client::connect(config, addrs, sh).await?; + + // use publickey authentication, with or without certificate + if openssh_cert.is_none() { + let auth_res = session + .authenticate_publickey( + user, + PrivateKeyWithHashAlg::new( + Arc::new(key_pair), + session.best_supported_rsa_hash().await?.flatten(), + ), + ) + .await?; + + if !auth_res.success() { + anyhow::bail!("Authentication (with publickey) failed"); + } + } else { + let auth_res = session + .authenticate_openssh_cert(user, Arc::new(key_pair), openssh_cert.unwrap()) + .await?; + + if !auth_res.success() { + anyhow::bail!("Authentication (with publickey+cert) failed"); + } + } + + Ok(Self { session }) + } + + async fn call(&mut self, command: &str) -> Result { + let mut channel = self.session.channel_open_session().await?; + + // This example doesn't terminal resizing after the connection is established + let (w, h) = termion::terminal_size()?; + + // Request an interactive PTY from the server + channel + .request_pty( + false, + &env::var("TERM").unwrap_or("xterm".into()), + w as u32, + h as u32, + 0, + 0, + &[], // ideally you want to pass the actual terminal modes here + ) + .await?; + channel.exec(true, command).await?; + + let code; + let mut stdin = tokio_fd::AsyncFd::try_from(0)?; + let mut stdout = tokio_fd::AsyncFd::try_from(1)?; + let mut buf = vec![0; 1024]; + let mut stdin_closed = false; + + loop { + // Handle one of the possible events: + tokio::select! { + // There's terminal input available from the user + r = stdin.read(&mut buf), if !stdin_closed => { + match r { + Ok(0) => { + stdin_closed = true; + channel.eof().await?; + }, + // Send it to the server + Ok(n) => channel.data(&buf[..n]).await?, + Err(e) => return Err(e.into()), + }; + }, + // There's an event available on the session channel + Some(msg) = channel.wait() => { + match msg { + // Write data to the terminal + ChannelMsg::Data { ref data } => { + stdout.write_all(data).await?; + stdout.flush().await?; + } + // The command has returned an exit code + ChannelMsg::ExitStatus { exit_status } => { + code = exit_status; + if !stdin_closed { + channel.eof().await?; + } + break; + } + _ => {} + } + }, + } + } + Ok(code) + } + + async fn close(&mut self) -> Result<()> { + self.session + .disconnect(Disconnect::ByApplication, "", "English") + .await?; + Ok(()) + } +} + +#[derive(clap::Parser)] +#[clap(trailing_var_arg = true)] +pub struct Cli { + #[clap(index = 1)] + host: String, + + #[clap(long, short, default_value_t = 22)] + port: u16, + + #[clap(long, short)] + username: Option, + + #[clap(long, short = 'k')] + private_key: PathBuf, + + #[clap(long, short = 'o')] + openssh_certificate: Option, + + #[clap(multiple = true, index = 2, required = true)] + command: Vec, +} diff --git a/russh/examples/client_exec_simple.rs b/russh/examples/client_exec_simple.rs new file mode 100644 index 00000000..694b27b6 --- /dev/null +++ b/russh/examples/client_exec_simple.rs @@ -0,0 +1,194 @@ +use std::borrow::Cow; +/// +/// Run this example with: +/// cargo run --example client_exec_simple -- -k +/// +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use clap::Parser; +use log::info; +use russh::keys::*; +use russh::*; +use tokio::io::AsyncWriteExt; +use tokio::net::ToSocketAddrs; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::builder() + .filter_level(log::LevelFilter::Debug) + .init(); + + // CLI options are defined later in this file + let cli = Cli::parse(); + + info!("Connecting to {}:{}", cli.host, cli.port); + info!("Key path: {:?}", cli.private_key); + info!("OpenSSH Certificate path: {:?}", cli.openssh_certificate); + + // Session is a wrapper around a russh client, defined down below + let mut ssh = Session::connect( + cli.private_key, + cli.username.unwrap_or("root".to_string()), + cli.openssh_certificate, + (cli.host, cli.port), + ) + .await?; + info!("Connected"); + + let code = ssh + .call( + &cli.command + .into_iter() + .map(|x| shell_escape::escape(x.into())) // arguments are escaped manually since the SSH protocol doesn't support quoting + .collect::>() + .join(" "), + ) + .await?; + + println!("Exitcode: {:?}", code); + ssh.close().await?; + Ok(()) +} + +struct Client {} + +// More SSH event handlers +// can be defined in this trait +// In this example, we're only using Channel, so these aren't needed. +impl client::Handler for Client { + type Error = russh::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &ssh_key::PublicKey, + ) -> Result { + Ok(true) + } +} + +/// This struct is a convenience wrapper +/// around a russh client +pub struct Session { + session: client::Handle, +} + +impl Session { + async fn connect, A: ToSocketAddrs>( + key_path: P, + user: impl Into, + openssh_cert_path: Option

, + addrs: A, + ) -> Result { + let key_pair = load_secret_key(key_path, None)?; + + // load ssh certificate + let mut openssh_cert = None; + if openssh_cert_path.is_some() { + openssh_cert = Some(load_openssh_certificate(openssh_cert_path.unwrap())?); + } + + let config = client::Config { + inactivity_timeout: Some(Duration::from_secs(5)), + preferred: Preferred { + kex: Cow::Owned(vec![ + russh::kex::CURVE25519_PRE_RFC_8731, + russh::kex::EXTENSION_SUPPORT_AS_CLIENT, + ]), + ..Default::default() + }, + ..<_>::default() + }; + + let config = Arc::new(config); + let sh = Client {}; + + let mut session = client::connect(config, addrs, sh).await?; + // use publickey authentication, with or without certificate + if openssh_cert.is_none() { + let auth_res = session + .authenticate_publickey( + user, + PrivateKeyWithHashAlg::new( + Arc::new(key_pair), + session.best_supported_rsa_hash().await?.flatten(), + ), + ) + .await?; + + if !auth_res.success() { + anyhow::bail!("Authentication (with publickey) failed"); + } + } else { + let auth_res = session + .authenticate_openssh_cert(user, Arc::new(key_pair), openssh_cert.unwrap()) + .await?; + + if !auth_res.success() { + anyhow::bail!("Authentication (with publickey+cert) failed"); + } + } + + Ok(Self { session }) + } + + async fn call(&mut self, command: &str) -> Result { + let mut channel = self.session.channel_open_session().await?; + channel.exec(true, command).await?; + + let mut code = None; + let mut stdout = tokio::io::stdout(); + + loop { + // There's an event available on the session channel + let Some(msg) = channel.wait().await else { + break; + }; + match msg { + // Write data to the terminal + ChannelMsg::Data { ref data } => { + stdout.write_all(data).await?; + stdout.flush().await?; + } + // The command has returned an exit code + ChannelMsg::ExitStatus { exit_status } => { + code = Some(exit_status); + // cannot leave the loop immediately, there might still be more data to receive + } + _ => {} + } + } + Ok(code.expect("program did not exit cleanly")) + } + + async fn close(&mut self) -> Result<()> { + self.session + .disconnect(Disconnect::ByApplication, "", "English") + .await?; + Ok(()) + } +} + +#[derive(clap::Parser)] +#[clap(trailing_var_arg = true)] +pub struct Cli { + #[clap(index = 1)] + host: String, + + #[clap(long, short, default_value_t = 22)] + port: u16, + + #[clap(long, short)] + username: Option, + + #[clap(long, short = 'k')] + private_key: PathBuf, + + #[clap(long, short = 'o')] + openssh_certificate: Option, + + #[clap(multiple = true, index = 2, required = true)] + command: Vec, +} diff --git a/russh/examples/client_open_direct_tcpip.rs b/russh/examples/client_open_direct_tcpip.rs new file mode 100644 index 00000000..db29b594 --- /dev/null +++ b/russh/examples/client_open_direct_tcpip.rs @@ -0,0 +1,211 @@ +use std::net::SocketAddr; +/// +/// Run this example with: +/// cargo run --example client_open_direct_tcpip -- --private-key --local-addr --forward-addr +/// +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use anyhow::Result; +use clap::Parser; +use key::PrivateKeyWithHashAlg; +use log::info; +use russh::keys::*; +use russh::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .init(); + + // CLI options are defined later in this file + let cli = Cli::parse(); + + info!("Connecting to server: {}:{}", cli.host, cli.port); + info!("Key path: {:?}", cli.private_key); + info!("OpenSSH Certificate path: {:?}", cli.openssh_certificate); + + let forward_addr: SocketAddr = cli.forward_addr.parse()?; + let listener = TcpListener::bind(&cli.local_addr).await?; + info!("listen on: {}", &cli.local_addr); + + // Session is a wrapper around a russh client, defined down below + let mut ssh = Session::connect( + cli.private_key, + cli.openssh_certificate, + cli.username.unwrap_or("root".to_string()), + (cli.host.clone(), cli.port), + ) + .await?; + info!("Server: {}:{} Connected", cli.host, cli.port); + + let (socket, o_addr) = listener.accept().await?; + info!("originator address: {}", o_addr); + ssh.call(socket, o_addr, forward_addr).await?; + + ssh.close().await?; + + Ok(()) +} + +struct Client {} + +// More SSH event handlers +// can be defined in this trait +// In this example, we're only using Channel, so these aren't needed. +impl client::Handler for Client { + type Error = russh::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &ssh_key::PublicKey, + ) -> Result { + Ok(true) + } +} + +/// This struct is a convenience wrapper +/// around a russh client +pub struct Session { + session: client::Handle, +} + +impl Session { + async fn connect, A: ToSocketAddrs>( + key_path: P, + openssh_cert_path: Option

, + user: impl Into, + addrs: A, + ) -> Result { + let key_pair = load_secret_key(key_path, None)?; + let config = client::Config::default(); + // load ssh certificate + let mut openssh_cert = None; + if openssh_cert_path.is_some() { + openssh_cert = Some(load_openssh_certificate(openssh_cert_path.unwrap())?); + } + + let config = Arc::new(config); + let sh = Client {}; + + let mut session = client::connect(config, addrs, sh).await?; + // use publickey authentication, with or without certificate + if openssh_cert.is_none() { + let auth_res = session + .authenticate_publickey( + user, + PrivateKeyWithHashAlg::new( + Arc::new(key_pair), + session.best_supported_rsa_hash().await?.flatten(), + ), + ) + .await?; + + if !auth_res.success() { + anyhow::bail!("Authentication (with publickey) failed"); + } + } else { + let auth_res = session + .authenticate_openssh_cert(user, Arc::new(key_pair), openssh_cert.unwrap()) + .await?; + + if !auth_res.success() { + anyhow::bail!("Authentication (with publickey+cert) failed"); + } + } + + Ok(Self { session }) + } + + async fn call( + &mut self, + mut stream: TcpStream, + originator_addr: SocketAddr, + forward_addr: SocketAddr, + ) -> Result<()> { + let mut channel = self + .session + .channel_open_direct_tcpip( + forward_addr.ip().to_string(), + forward_addr.port().into(), + originator_addr.ip().to_string(), + originator_addr.port().into(), + ) + .await?; + // There's an event available on the session channel + let mut stream_closed = false; + let mut buf = vec![0; 65536]; + loop { + // Handle one of the possible events: + tokio::select! { + // There's socket input available from the client + r = stream.read(&mut buf), if !stream_closed => { + match r { + Ok(0) => { + stream_closed = true; + channel.eof().await?; + }, + // Send it to the server + Ok(n) => channel.data(&buf[..n]).await?, + Err(e) => return Err(e.into()), + }; + }, + // There's an event available on the session channel + Some(msg) = channel.wait() => { + match msg { + // Write data to the client + ChannelMsg::Data { ref data } => { + stream.write_all(data).await?; + } + ChannelMsg::Eof => { + if !stream_closed { + channel.eof().await?; + } + break; + } + ChannelMsg::WindowAdjusted { new_size:_ }=> { + // Ignore this message type + } + _ => {todo!()} + } + }, + } + } + Ok(()) + } + + async fn close(&mut self) -> Result<()> { + self.session + .disconnect(Disconnect::ByApplication, "", "English") + .await?; + Ok(()) + } +} + +#[derive(clap::Parser)] +#[clap(trailing_var_arg = true)] +pub struct Cli { + #[clap(index = 1)] + host: String, + + #[clap(long, short, default_value_t = 22)] + port: u16, + + #[clap(long, short = 'o')] + openssh_certificate: Option, + + #[clap(long, short)] + username: Option, + + #[clap(long, short = 'k')] + private_key: PathBuf, + + #[clap(long, short = 'l')] + local_addr: String, + + #[clap(long, short = 'f')] + forward_addr: String, +} diff --git a/russh/examples/echoserver.rs b/russh/examples/echoserver.rs index 64dbae92..6123297b 100644 --- a/russh/examples/echoserver.rs +++ b/russh/examples/echoserver.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; use std::sync::Arc; -use async_trait::async_trait; -use russh::server::{Msg, Session}; +use rand_core::OsRng; +use russh::keys::{Certificate, *}; +use russh::server::{Msg, Server as _, Session}; use russh::*; -use russh_keys::*; use tokio::sync::Mutex; #[tokio::main] @@ -17,29 +17,33 @@ async fn main() { inactivity_timeout: Some(std::time::Duration::from_secs(3600)), auth_rejection_time: std::time::Duration::from_secs(3), auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), - keys: vec![russh_keys::key::KeyPair::generate_ed25519().unwrap()], + keys: vec![ + russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap(), + ], + preferred: Preferred { + // kex: std::borrow::Cow::Owned(vec![russh::kex::DH_GEX_SHA256]), + ..Preferred::default() + }, ..Default::default() }; let config = Arc::new(config); - let sh = Server { + let mut sh = Server { clients: Arc::new(Mutex::new(HashMap::new())), id: 0, }; - russh::server::run(config, ("0.0.0.0", 2222), sh) - .await - .unwrap(); + sh.run_on_address(config, ("0.0.0.0", 2222)).await.unwrap(); } #[derive(Clone)] struct Server { - clients: Arc>>, + clients: Arc>>, id: usize, } impl Server { async fn post(&mut self, data: CryptoVec) { let mut clients = self.clients.lock().await; - for ((id, channel), ref mut s) in clients.iter_mut() { + for (id, (channel, ref mut s)) in clients.iter_mut() { if *id != self.id { let _ = s.data(*channel, data.clone()).await; } @@ -54,61 +58,87 @@ impl server::Server for Server { self.id += 1; s } + fn handle_session_error(&mut self, _error: ::Error) { + eprintln!("Session error: {:#?}", _error); + } } -#[async_trait] impl server::Handler for Server { - type Error = anyhow::Error; + type Error = russh::Error; async fn channel_open_session( - self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + session: &mut Session, + ) -> Result { { let mut clients = self.clients.lock().await; - clients.insert((self.id, channel.id()), session.handle()); + clients.insert(self.id, (channel.id(), session.handle())); } - Ok((self, true, session)) + Ok(true) } async fn auth_publickey( - self, + &mut self, _: &str, - _: &key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) + _key: &ssh_key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + + async fn auth_openssh_certificate( + &mut self, + _user: &str, + _certificate: &Certificate, + ) -> Result { + Ok(server::Auth::Accept) } async fn data( - mut self, + &mut self, channel: ChannelId, data: &[u8], - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { + // Sending Ctrl+C ends the session and disconnects the client + if data == [3] { + return Err(russh::Error::Disconnect); + } + let data = CryptoVec::from(format!("Got data: {}\r\n", String::from_utf8_lossy(data))); self.post(data.clone()).await; - session.data(channel, data); - Ok((self, session)) + session.data(channel, data)?; + Ok(()) } async fn tcpip_forward( - self, + &mut self, address: &str, port: &mut u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + session: &mut Session, + ) -> Result { let handle = session.handle(); let address = address.to_string(); let port = *port; tokio::spawn(async move { - let mut channel = handle + let channel = handle .channel_open_forwarded_tcpip(address, port, "1.2.3.4", 1234) .await .unwrap(); let _ = channel.data(&b"Hello from a forwarded port"[..]).await; let _ = channel.eof().await; }); - Ok((self, true, session)) + Ok(true) + } +} + +impl Drop for Server { + fn drop(&mut self) { + let id = self.id; + let clients = self.clients.clone(); + tokio::spawn(async move { + let mut clients = clients.lock().await; + clients.remove(&id); + }); } } diff --git a/russh/examples/ratatui_app.rs b/russh/examples/ratatui_app.rs new file mode 100644 index 00000000..757c976e --- /dev/null +++ b/russh/examples/ratatui_app.rs @@ -0,0 +1,271 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use rand_core::OsRng; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::Rect; +use ratatui::style::{Color, Style}; +use ratatui::widgets::{Block, Borders, Clear, Paragraph}; +use ratatui::{Terminal, TerminalOptions, Viewport}; +use russh::keys::ssh_key::PublicKey; +use russh::server::*; +use russh::{Channel, ChannelId, Pty}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::Mutex; + +type SshTerminal = Terminal>; + +struct App { + pub counter: usize, +} + +impl App { + pub fn new() -> App { + Self { counter: 0 } + } +} + +struct TerminalHandle { + sender: UnboundedSender>, + // The sink collects the data which is finally sent to sender. + sink: Vec, +} + +impl TerminalHandle { + async fn start(handle: Handle, channel_id: ChannelId) -> Self { + let (sender, mut receiver) = unbounded_channel::>(); + tokio::spawn(async move { + while let Some(data) = receiver.recv().await { + let result = handle.data(channel_id, data.into()).await; + if result.is_err() { + eprintln!("Failed to send data: {:?}", result); + } + } + }); + Self { + sender, + sink: Vec::new(), + } + } +} + +// The crossterm backend writes to the terminal handle. +impl std::io::Write for TerminalHandle { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.sink.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + let result = self.sender.send(self.sink.clone()); + if result.is_err() { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + result.unwrap_err(), + )); + } + + self.sink.clear(); + Ok(()) + } +} + +#[derive(Clone)] +struct AppServer { + clients: Arc>>, + id: usize, +} + +impl AppServer { + pub fn new() -> Self { + Self { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + } + } + + pub async fn run(&mut self) -> Result<(), anyhow::Error> { + let clients = self.clients.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + for (_, (terminal, app)) in clients.lock().await.iter_mut() { + app.counter += 1; + + terminal + .draw(|f| { + let area = f.area(); + f.render_widget(Clear, area); + let style = match app.counter % 3 { + 0 => Style::default().fg(Color::Red), + 1 => Style::default().fg(Color::Green), + _ => Style::default().fg(Color::Blue), + }; + let paragraph = Paragraph::new(format!("Counter: {}", app.counter)) + .alignment(ratatui::layout::Alignment::Center) + .style(style); + let block = Block::default() + .title("Press 'c' to reset the counter!") + .borders(Borders::ALL); + f.render_widget(paragraph.block(block), area); + }) + .unwrap(); + } + } + }); + + let config = Config { + inactivity_timeout: Some(std::time::Duration::from_secs(3600)), + auth_rejection_time: std::time::Duration::from_secs(3), + auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), + keys: vec![ + russh::keys::PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(), + ], + nodelay: true, + ..Default::default() + }; + + self.run_on_address(Arc::new(config), ("0.0.0.0", 2222)) + .await?; + Ok(()) + } +} + +impl Server for AppServer { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } +} + +impl Handler for AppServer { + type Error = anyhow::Error; + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + let terminal_handle = TerminalHandle::start(session.handle(), channel.id()).await; + + let backend = CrosstermBackend::new(terminal_handle); + + // the correct viewport area will be set when the client request a pty + let options = TerminalOptions { + viewport: Viewport::Fixed(Rect::default()), + }; + + let terminal = Terminal::with_options(backend, options)?; + let app = App::new(); + + let mut clients = self.clients.lock().await; + clients.insert(self.id, (terminal, app)); + + Ok(true) + } + + async fn auth_publickey(&mut self, _: &str, _: &PublicKey) -> Result { + Ok(Auth::Accept) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + match data { + // Pressing 'q' closes the connection. + b"q" => { + self.clients.lock().await.remove(&self.id); + session.close(channel)?; + } + // Pressing 'c' resets the counter for the app. + // Only the client with the id sees the counter reset. + b"c" => { + let mut clients = self.clients.lock().await; + let (_, app) = clients.get_mut(&self.id).unwrap(); + app.counter = 0; + } + _ => {} + } + + Ok(()) + } + + /// The client's window size has changed. + async fn window_change_request( + &mut self, + _: ChannelId, + col_width: u32, + row_height: u32, + _: u32, + _: u32, + _: &mut Session, + ) -> Result<(), Self::Error> { + let rect = Rect { + x: 0, + y: 0, + width: col_width as u16, + height: row_height as u16, + }; + + let mut clients = self.clients.lock().await; + let (terminal, _) = clients.get_mut(&self.id).unwrap(); + terminal.resize(rect)?; + + Ok(()) + } + + /// The client requests a pseudo-terminal with the given + /// specifications. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. + async fn pty_request( + &mut self, + channel: ChannelId, + _: &str, + col_width: u32, + row_height: u32, + _: u32, + _: u32, + _: &[(Pty, u32)], + session: &mut Session, + ) -> Result<(), Self::Error> { + let rect = Rect { + x: 0, + y: 0, + width: col_width as u16, + height: row_height as u16, + }; + + let mut clients = self.clients.lock().await; + let (terminal, _) = clients.get_mut(&self.id).unwrap(); + terminal.resize(rect)?; + + session.channel_success(channel)?; + + Ok(()) + } +} + +impl Drop for AppServer { + fn drop(&mut self) { + let id = self.id; + let clients = self.clients.clone(); + tokio::spawn(async move { + let mut clients = clients.lock().await; + clients.remove(&id); + }); + } +} + +#[tokio::main] +async fn main() { + let mut server = AppServer::new(); + server.run().await.expect("Failed running server"); +} diff --git a/russh/examples/ratatui_shared_app.rs b/russh/examples/ratatui_shared_app.rs new file mode 100644 index 00000000..d221c382 --- /dev/null +++ b/russh/examples/ratatui_shared_app.rs @@ -0,0 +1,270 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use rand_core::OsRng; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::Rect; +use ratatui::style::{Color, Style}; +use ratatui::widgets::{Block, Borders, Clear, Paragraph}; +use ratatui::{Terminal, TerminalOptions, Viewport}; +use russh::keys::ssh_key::PublicKey; +use russh::server::*; +use russh::{Channel, ChannelId, Pty}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::Mutex; + +type SshTerminal = Terminal>; + +struct App { + pub counter: usize, +} + +impl App { + pub fn new() -> App { + Self { counter: 0 } + } +} + +struct TerminalHandle { + sender: UnboundedSender>, + // The sink collects the data which is finally sent to sender. + sink: Vec, +} + +impl TerminalHandle { + async fn start(handle: Handle, channel_id: ChannelId) -> Self { + let (sender, mut receiver) = unbounded_channel::>(); + tokio::spawn(async move { + while let Some(data) = receiver.recv().await { + let result = handle.data(channel_id, data.into()).await; + if result.is_err() { + eprintln!("Failed to send data: {:?}", result); + } + } + }); + Self { + sender, + sink: Vec::new(), + } + } +} + +// The crossterm backend writes to the terminal handle. +impl std::io::Write for TerminalHandle { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.sink.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + let result = self.sender.send(self.sink.clone()); + if result.is_err() { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + result.unwrap_err(), + )); + } + + self.sink.clear(); + Ok(()) + } +} + +#[derive(Clone)] +struct AppServer { + clients: Arc>>, + id: usize, + app: Arc>, +} + +impl AppServer { + pub fn new() -> Self { + Self { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + app: Arc::new(Mutex::new(App::new())), + } + } + + pub async fn run(&mut self) -> Result<(), anyhow::Error> { + let app = self.app.clone(); + let clients = self.clients.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + app.lock().await.counter += 1; + let counter = app.lock().await.counter; + for (_, terminal) in clients.lock().await.iter_mut() { + terminal + .draw(|f| { + let area = f.area(); + f.render_widget(Clear, area); + let style = match counter % 3 { + 0 => Style::default().fg(Color::Red), + 1 => Style::default().fg(Color::Green), + _ => Style::default().fg(Color::Blue), + }; + let paragraph = Paragraph::new(format!("Counter: {counter}")) + .alignment(ratatui::layout::Alignment::Center) + .style(style); + let block = Block::default() + .title("Press 'c' to reset the counter!") + .borders(Borders::ALL); + f.render_widget(paragraph.block(block), area); + }) + .unwrap(); + } + } + }); + + let config = Config { + inactivity_timeout: Some(std::time::Duration::from_secs(3600)), + auth_rejection_time: std::time::Duration::from_secs(3), + auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), + keys: vec![ + russh::keys::PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(), + ], + nodelay: true, + ..Default::default() + }; + + self.run_on_address(Arc::new(config), ("0.0.0.0", 2222)) + .await?; + Ok(()) + } +} + +impl Server for AppServer { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } +} + +impl Handler for AppServer { + type Error = anyhow::Error; + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + let terminal_handle = TerminalHandle::start(session.handle(), channel.id()).await; + + let backend = CrosstermBackend::new(terminal_handle); + + // the correct viewport area will be set when the client request a pty + let options = TerminalOptions { + viewport: Viewport::Fixed(Rect::default()), + }; + + let terminal = Terminal::with_options(backend, options)?; + + let mut clients = self.clients.lock().await; + clients.insert(self.id, terminal); + + Ok(true) + } + + async fn auth_publickey(&mut self, _: &str, _: &PublicKey) -> Result { + Ok(Auth::Accept) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + let app = self.app.clone(); + match data { + // Pressing 'q' closes the connection. + b"q" => { + self.clients.lock().await.remove(&self.id); + session.close(channel)?; + } + // Pressing 'c' resets the counter for the app. + // Every client sees the counter reset. + b"c" => { + app.lock().await.counter = 0; + } + _ => {} + } + + Ok(()) + } + + /// The client's pseudo-terminal window size has changed. + async fn window_change_request( + &mut self, + _: ChannelId, + col_width: u32, + row_height: u32, + _: u32, + _: u32, + _: &mut Session, + ) -> Result<(), Self::Error> { + let rect = Rect { + x: 0, + y: 0, + width: col_width as u16, + height: row_height as u16, + }; + + let mut clients = self.clients.lock().await; + clients.get_mut(&self.id).unwrap().resize(rect)?; + + Ok(()) + } + + /// The client requests a pseudo-terminal with the given + /// specifications. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. + async fn pty_request( + &mut self, + channel: ChannelId, + _: &str, + col_width: u32, + row_height: u32, + _: u32, + _: u32, + _: &[(Pty, u32)], + session: &mut Session, + ) -> Result<(), Self::Error> { + let rect = Rect { + x: 0, + y: 0, + width: col_width as u16, + height: row_height as u16, + }; + + let mut clients = self.clients.lock().await; + let terminal = clients.get_mut(&self.id).unwrap(); + terminal.resize(rect)?; + + session.channel_success(channel)?; + + Ok(()) + } +} + +impl Drop for AppServer { + fn drop(&mut self) { + let id = self.id; + let clients = self.clients.clone(); + tokio::spawn(async move { + let mut clients = clients.lock().await; + clients.remove(&id); + }); + } +} + +#[tokio::main] +async fn main() { + let mut server = AppServer::new(); + server.run().await.expect("Failed running server"); +} diff --git a/russh/examples/remote_shell_call.rs b/russh/examples/remote_shell_call.rs deleted file mode 100644 index b77750b4..00000000 --- a/russh/examples/remote_shell_call.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::io::Write; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use async_trait::async_trait; -use log::info; -use russh::*; -use russh_keys::*; -use tokio::net::ToSocketAddrs; - -#[tokio::main] -async fn main() -> Result<()> { - env_logger::builder() - .filter_level(log::LevelFilter::Debug) - .init(); - - let args: Vec = std::env::args().collect(); - let (host, key) = match args.get(1..3) { - Some(args) => (&args[0], &args[1]), - None => { - eprintln!("Usage: {} ", args[0]); - std::process::exit(1); - } - }; - - info!("Connecting to {host}"); - info!("Key path: {key}"); - - let mut ssh = Session::connect(key, "root", host).await?; - let r = ssh.call("whoami").await?; - assert!(r.success()); - println!("Result: {}", r.output()); - ssh.close().await?; - Ok(()) -} - -struct Client {} - -#[async_trait] -impl client::Handler for Client { - type Error = russh::Error; - - async fn check_server_key( - self, - _server_public_key: &key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } -} - -pub struct Session { - session: client::Handle, -} - -impl Session { - async fn connect, A: ToSocketAddrs>( - key_path: P, - user: impl Into, - addrs: A, - ) -> Result { - let key_pair = load_secret_key(key_path, None)?; - let config = client::Config { - inactivity_timeout: Some(Duration::from_secs(5)), - ..<_>::default() - }; - let config = Arc::new(config); - let sh = Client {}; - let mut session = client::connect(config, addrs, sh).await?; - let _auth_res = session - .authenticate_publickey(user, Arc::new(key_pair)) - .await?; - Ok(Self { session }) - } - - async fn call(&mut self, command: &str) -> Result { - let mut channel = self.session.channel_open_session().await?; - channel.exec(true, command).await?; - let mut output = Vec::new(); - let mut code = None; - while let Some(msg) = channel.wait().await { - match msg { - russh::ChannelMsg::Data { ref data } => { - output.write_all(data).unwrap(); - } - russh::ChannelMsg::ExitStatus { exit_status } => { - code = Some(exit_status); - } - _ => {} - } - } - Ok(CommandResult { output, code }) - } - - async fn close(&mut self) -> Result<()> { - self.session - .disconnect(Disconnect::ByApplication, "", "English") - .await?; - Ok(()) - } -} - -struct CommandResult { - output: Vec, - code: Option, -} - -impl CommandResult { - fn output(&self) -> String { - String::from_utf8_lossy(&self.output).into() - } - - fn success(&self) -> bool { - self.code == Some(0) - } -} diff --git a/russh/examples/sftp_client.rs b/russh/examples/sftp_client.rs new file mode 100644 index 00000000..d10f0dd6 --- /dev/null +++ b/russh/examples/sftp_client.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use log::{error, info, LevelFilter}; +use russh::keys::*; +use russh::*; +use russh_sftp::client::SftpSession; +use russh_sftp::protocol::OpenFlags; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + +struct Client; + +impl client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key( + &mut self, + server_public_key: &ssh_key::PublicKey, + ) -> Result { + info!("check_server_key: {:?}", server_public_key); + Ok(true) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + _session: &mut client::Session, + ) -> Result<(), Self::Error> { + info!("data on channel {:?}: {}", channel, data.len()); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + env_logger::builder() + .filter_level(LevelFilter::Debug) + .init(); + + let config = russh::client::Config::default(); + let sh = Client {}; + let mut session = russh::client::connect(Arc::new(config), ("localhost", 22), sh) + .await + .unwrap(); + if session + .authenticate_password("root", "password") + .await + .unwrap() + .success() + { + let channel = session.channel_open_session().await.unwrap(); + channel.request_subsystem(true, "sftp").await.unwrap(); + let sftp = SftpSession::new(channel.into_stream()).await.unwrap(); + info!("current path: {:?}", sftp.canonicalize(".").await.unwrap()); + + // create dir and symlink + let path = "./some_kind_of_dir"; + let symlink = "./symlink"; + + sftp.create_dir(path).await.unwrap(); + sftp.symlink(path, symlink).await.unwrap(); + + info!("dir info: {:?}", sftp.metadata(path).await.unwrap()); + info!( + "symlink info: {:?}", + sftp.symlink_metadata(path).await.unwrap() + ); + + // scanning directory + for entry in sftp.read_dir(".").await.unwrap() { + info!("file in directory: {:?}", entry.file_name()); + } + + sftp.remove_file(symlink).await.unwrap(); + sftp.remove_dir(path).await.unwrap(); + + // interaction with i/o + let filename = "test_new.txt"; + let mut file = sftp + .open_with_flags( + filename, + OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ, + ) + .await + .unwrap(); + info!("metadata by handle: {:?}", file.metadata().await.unwrap()); + + file.write_all(b"magic text").await.unwrap(); + info!("flush: {:?}", file.flush().await); // or file.sync_all() + info!( + "current cursor position: {:?}", + file.stream_position().await + ); + + let mut str = String::new(); + + file.rewind().await.unwrap(); + file.read_to_string(&mut str).await.unwrap(); + file.rewind().await.unwrap(); + + info!( + "our magical contents: {}, after rewind: {:?}", + str, + file.stream_position().await + ); + + file.shutdown().await.unwrap(); + sftp.remove_file(filename).await.unwrap(); + + // should fail because handle was closed + error!("should fail: {:?}", file.read_u8().await); + } +} diff --git a/russh/examples/sftp_server.rs b/russh/examples/sftp_server.rs index 4ce70153..9ec1d33c 100644 --- a/russh/examples/sftp_server.rs +++ b/russh/examples/sftp_server.rs @@ -1,12 +1,13 @@ -use async_trait::async_trait; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + use log::{error, info, LevelFilter}; -use russh::{ - server::{Auth, Msg, Session}, - Channel, ChannelId, -}; -use russh_keys::key::KeyPair; +use rand_core::OsRng; +use russh::server::{Auth, Msg, Server as _, Session}; +use russh::{Channel, ChannelId}; use russh_sftp::protocol::{File, FileAttributes, Handle, Name, Status, StatusCode, Version}; -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; use tokio::sync::Mutex; #[derive(Clone)] @@ -39,72 +40,73 @@ impl SshSession { } } -#[async_trait] impl russh::server::Handler for SshSession { type Error = anyhow::Error; - async fn auth_password(self, user: &str, password: &str) -> Result<(Self, Auth), Self::Error> { + async fn auth_password(&mut self, user: &str, password: &str) -> Result { info!("credentials: {}, {}", user, password); - Ok((self, Auth::Accept)) + Ok(Auth::Accept) } async fn auth_publickey( - self, + &mut self, user: &str, - public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, Auth), Self::Error> { + public_key: &russh::keys::ssh_key::PublicKey, + ) -> Result { info!("credentials: {}, {:?}", user, public_key); - Ok((self, Auth::Accept)) + Ok(Auth::Accept) } async fn channel_open_session( - mut self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + _session: &mut Session, + ) -> Result { { let mut clients = self.clients.lock().await; clients.insert(channel.id(), channel); } - Ok((self, true, session)) + Ok(true) + } + + async fn channel_eof( + &mut self, + channel: ChannelId, + session: &mut Session, + ) -> Result<(), Self::Error> { + // After a client has sent an EOF, indicating that they don't want + // to send more data in this session, the channel can be closed. + session.close(channel)?; + Ok(()) } async fn subsystem_request( - mut self, + &mut self, channel_id: ChannelId, name: &str, - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { info!("subsystem: {}", name); if name == "sftp" { let channel = self.get_channel(channel_id).await; let sftp = SftpSession::default(); - session.channel_success(channel_id); + session.channel_success(channel_id)?; russh_sftp::server::run(channel.into_stream(), sftp).await; } else { - session.channel_failure(channel_id); + session.channel_failure(channel_id)?; } - Ok((self, session)) + Ok(()) } } +#[derive(Default)] struct SftpSession { version: Option, root_dir_read_done: bool, } -impl Default for SftpSession { - fn default() -> Self { - Self { - version: None, - root_dir_read_done: false, - } - } -} - -#[async_trait] impl russh_sftp::server::Handler for SftpSession { type Error = StatusCode; @@ -149,28 +151,20 @@ impl russh_sftp::server::Handler for SftpSession { return Ok(Name { id, files: vec![ - File { - filename: "foo".to_string(), - attrs: FileAttributes::default(), - }, - File { - filename: "bar".to_string(), - attrs: FileAttributes::default(), - }, + File::new("foo", FileAttributes::default()), + File::new("bar", FileAttributes::default()), ], }); } - Ok(Name { id, files: vec![] }) + // If all files have been sent to the client, respond with an EOF + Err(StatusCode::Eof) } async fn realpath(&mut self, id: u32, path: String) -> Result { info!("realpath: {}", path); Ok(Name { id, - files: vec![File { - filename: "/".to_string(), - attrs: FileAttributes::default(), - }], + files: vec![File::dummy("/")], }) } } @@ -184,24 +178,25 @@ async fn main() { let config = russh::server::Config { auth_rejection_time: Duration::from_secs(3), auth_rejection_time_initial: Some(Duration::from_secs(0)), - keys: vec![KeyPair::generate_ed25519().unwrap()], - inactivity_timeout: Some(Duration::from_secs(3600)), + keys: vec![ + russh::keys::PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(), + ], ..Default::default() }; - let server = Server; - - russh::server::run( - Arc::new(config), - ( - "0.0.0.0", - std::env::var("PORT") - .unwrap_or("22".to_string()) - .parse() - .unwrap(), - ), - server, - ) - .await - .unwrap(); + let mut server = Server; + + server + .run_on_address( + Arc::new(config), + ( + "0.0.0.0", + std::env::var("PORT") + .unwrap_or("22".to_string()) + .parse() + .unwrap(), + ), + ) + .await + .unwrap(); } diff --git a/russh/examples/test.rs b/russh/examples/test.rs index 034ea7e2..be139da0 100644 --- a/russh/examples/test.rs +++ b/russh/examples/test.rs @@ -1,11 +1,11 @@ -use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use log::debug; -use russh::server::{Auth, Msg, Session}; +use rand_core::OsRng; +use russh::keys::*; +use russh::server::{Auth, Msg, Server as _, Session}; use russh::*; -use russh_keys::*; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -14,15 +14,15 @@ async fn main() -> anyhow::Result<()> { config.auth_rejection_time = std::time::Duration::from_secs(3); config .keys - .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); + .push(russh::keys::PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); let config = Arc::new(config); - let sh = Server { + let mut sh = Server { clients: Arc::new(Mutex::new(HashMap::new())), id: 0, }; tokio::time::timeout( std::time::Duration::from_secs(60), - russh::server::run(config, ("0.0.0.0", 2222), sh), + sh.run_on_address(config, ("0.0.0.0", 2222)), ) .await .unwrap_or(Ok(()))?; @@ -46,54 +46,53 @@ impl server::Server for Server { } } -#[async_trait] impl server::Handler for Server { type Error = anyhow::Error; async fn channel_open_session( - self, + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { + _session: &mut Session, + ) -> Result { { debug!("channel open session"); let mut clients = self.clients.lock().unwrap(); clients.insert((self.id, channel.id()), channel); } - Ok((self, true, session)) + Ok(true) } /// The client requests a shell. #[allow(unused_variables)] async fn shell_request( - self, + &mut self, channel: ChannelId, - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { session.request_success(); - Ok((self, session)) + Ok(()) } async fn auth_publickey( - self, + &mut self, _: &str, - _: &key::PublicKey, - ) -> Result<(Self, Auth), Self::Error> { - Ok((self, server::Auth::Accept)) + _: &ssh_key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) } async fn data( - self, + &mut self, _channel: ChannelId, data: &[u8], - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { + session: &mut Session, + ) -> Result<(), Self::Error> { debug!("data: {data:?}"); { let mut clients = self.clients.lock().unwrap(); for ((_, _channel_id), ref mut channel) in clients.iter_mut() { - session.data(channel.id(), CryptoVec::from(data.to_vec())); + session.data(channel.id(), CryptoVec::from(data.to_vec()))?; } } - Ok((self, session)) + Ok(()) } } diff --git a/russh/src/auth.rs b/russh/src/auth.rs index e64f3a42..da2452c9 100644 --- a/russh/src/auth.rs +++ b/russh/src/auth.rs @@ -13,40 +13,156 @@ // limitations under the License. // +use std::future::Future; +use std::ops::Deref; +use std::str::FromStr; use std::sync::Arc; -use bitflags::bitflags; -use russh_cryptovec::CryptoVec; -use russh_keys::{encoding, key}; +use ssh_key::{Certificate, HashAlg, PrivateKey}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -bitflags! { - /// Set of authentication methods, represented by bit flags. - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - pub struct MethodSet: u32 { - /// The SSH `none` method (no authentication). - const NONE = 1; - /// The SSH `password` method (plaintext passwords). - const PASSWORD = 2; - /// The SSH `publickey` method (sign a challenge sent by the - /// server). - const PUBLICKEY = 4; - /// The SSH `hostbased` method (certain hostnames are allowed - /// by the server). - const HOSTBASED = 8; - /// The SSH `keyboard-interactive` method (answer to a - /// challenge, where the "challenge" can be a password prompt, - /// a bytestring to sign with a smartcard, or something else). - const KEYBOARD_INTERACTIVE = 16; +use crate::helpers::NameList; +use crate::keys::PrivateKeyWithHashAlg; +use crate::CryptoVec; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MethodKind { + None, + Password, + PublicKey, + HostBased, + KeyboardInteractive, +} + +impl From<&MethodKind> for &'static str { + fn from(value: &MethodKind) -> Self { + match value { + MethodKind::None => "none", + MethodKind::Password => "password", + MethodKind::PublicKey => "publickey", + MethodKind::HostBased => "hostbased", + MethodKind::KeyboardInteractive => "keyboard-interactive", + } + } +} + +impl FromStr for MethodKind { + fn from_str(b: &str) -> Result { + match b { + "none" => Ok(MethodKind::None), + "password" => Ok(MethodKind::Password), + "publickey" => Ok(MethodKind::PublicKey), + "hostbased" => Ok(MethodKind::HostBased), + "keyboard-interactive" => Ok(MethodKind::KeyboardInteractive), + _ => Err(()), + } + } + + type Err = (); +} + +impl From<&MethodKind> for String { + fn from(value: &MethodKind) -> Self { + <&str>::from(value).to_string() } } +/// An ordered set of authentication methods. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MethodSet(Vec); + +impl Deref for MethodSet { + type Target = [MethodKind]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From<&[MethodKind]> for MethodSet { + fn from(value: &[MethodKind]) -> Self { + let mut this = Self::empty(); + for method in value { + this.push(*method); + } + this + } +} + +impl From<&MethodSet> for NameList { + fn from(value: &MethodSet) -> Self { + Self(value.iter().map(|x| x.into()).collect()) + } +} + +impl From<&NameList> for MethodSet { + fn from(value: &NameList) -> Self { + Self( + value + .0 + .iter() + .filter_map(|x| MethodKind::from_str(x).ok()) + .collect(), + ) + } +} + +impl MethodSet { + pub fn empty() -> Self { + Self(Vec::new()) + } + + pub fn all() -> Self { + Self(vec![ + MethodKind::None, + MethodKind::Password, + MethodKind::PublicKey, + MethodKind::HostBased, + MethodKind::KeyboardInteractive, + ]) + } + + pub fn remove(&mut self, method: MethodKind) { + self.0.retain(|x| *x != method); + } + + /// Push a method to the end of the list. + /// If the method is already in the list, it is moved to the end. + pub fn push(&mut self, method: MethodKind) { + self.remove(method); + self.0.push(method); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthResult { + Success, + Failure { + /// The server suggests to proceed with these auth methods + remaining_methods: MethodSet, + /// The server says that though auth method has been accepted, + /// further authentication is required + partial_success: bool, + }, +} + +impl AuthResult { + pub fn success(&self) -> bool { + matches!(self, AuthResult::Success) + } +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] pub trait Signer: Sized { type Error: From; - type Future: futures::Future)> + Send; - fn auth_publickey_sign(self, key: &key::PublicKey, to_sign: CryptoVec) -> Self::Future; + fn auth_publickey_sign( + &mut self, + key: &ssh_key::PublicKey, + hash_alg: Option, + to_sign: CryptoVec, + ) -> impl Future> + Send; } #[derive(Debug, Error)] @@ -54,80 +170,98 @@ pub enum AgentAuthError { #[error(transparent)] Send(#[from] crate::SendError), #[error(transparent)] - Key(#[from] russh_keys::Error), + Key(#[from] crate::keys::Error), } +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] impl Signer - for russh_keys::agent::client::AgentClient + for crate::keys::agent::client::AgentClient { type Error = AgentAuthError; - #[allow(clippy::type_complexity)] - type Future = std::pin::Pin< - Box)> + Send>, - >; - fn auth_publickey_sign(self, key: &key::PublicKey, to_sign: CryptoVec) -> Self::Future { - let fut = self.sign_request(key, to_sign); - futures::FutureExt::boxed(async move { - let (a, b) = fut.await; - (a, b.map_err(AgentAuthError::Key)) - }) + + #[allow(clippy::manual_async_fn)] + fn auth_publickey_sign( + &mut self, + key: &ssh_key::PublicKey, + hash_alg: Option, + to_sign: CryptoVec, + ) -> impl Future> { + async move { + self.sign_request(key, hash_alg, to_sign) + .await + .map_err(Into::into) + } } } #[derive(Debug)] pub enum Method { None, - Password { password: String }, - PublicKey { key: Arc }, - FuturePublicKey { key: key::PublicKey }, - KeyboardInteractive { submethods: String }, + Password { + password: String, + }, + PublicKey { + key: PrivateKeyWithHashAlg, + }, + OpenSshCertificate { + key: Arc, + cert: Certificate, + }, + FuturePublicKey { + key: ssh_key::PublicKey, + hash_alg: Option, + }, + KeyboardInteractive { + submethods: String, + }, // Hostbased, } -impl encoding::Bytes for MethodSet { - fn bytes(&self) -> &'static [u8] { - match *self { - MethodSet::NONE => b"none", - MethodSet::PASSWORD => b"password", - MethodSet::PUBLICKEY => b"publickey", - MethodSet::HOSTBASED => b"hostbased", - MethodSet::KEYBOARD_INTERACTIVE => b"keyboard-interactive", - _ => b"", - } - } -} - -impl MethodSet { - pub(crate) fn from_bytes(b: &[u8]) -> Option { - match b { - b"none" => Some(MethodSet::NONE), - b"password" => Some(MethodSet::PASSWORD), - b"publickey" => Some(MethodSet::PUBLICKEY), - b"hostbased" => Some(MethodSet::HOSTBASED), - b"keyboard-interactive" => Some(MethodSet::KEYBOARD_INTERACTIVE), - _ => None, - } - } -} - #[doc(hidden)] #[derive(Debug)] pub struct AuthRequest { pub methods: MethodSet, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub partial_success: bool, pub current: Option, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub rejection_count: usize, } #[doc(hidden)] #[derive(Debug)] pub enum CurrentRequest { + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] PublicKey { + #[allow(dead_code)] key: CryptoVec, + #[allow(dead_code)] algo: CryptoVec, sent_pk_ok: bool, }, KeyboardInteractive { + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] submethods: String, }, } + +impl AuthRequest { + pub(crate) fn new(method: &Method) -> Self { + match method { + Method::KeyboardInteractive { submethods } => Self { + methods: MethodSet::all(), + partial_success: false, + current: Some(CurrentRequest::KeyboardInteractive { + submethods: submethods.to_string(), + }), + rejection_count: 0, + }, + _ => Self { + methods: MethodSet::all(), + partial_success: false, + current: None, + rejection_count: 0, + }, + } + } +} diff --git a/russh/src/cert.rs b/russh/src/cert.rs new file mode 100644 index 00000000..26cac9ba --- /dev/null +++ b/russh/src/cert.rs @@ -0,0 +1,45 @@ +use ssh_key::{Certificate, HashAlg, PublicKey}; +#[cfg(not(target_arch = "wasm32"))] +use { + crate::helpers::AlgorithmExt, ssh_encoding::Decode, ssh_key::public::KeyData, + ssh_key::Algorithm, +}; + +use crate::keys::key::PrivateKeyWithHashAlg; + +#[derive(Debug)] +pub(crate) enum PublicKeyOrCertificate { + PublicKey { + key: PublicKey, + hash_alg: Option, + }, + Certificate(Certificate), +} + +impl From<&PrivateKeyWithHashAlg> for PublicKeyOrCertificate { + fn from(key: &PrivateKeyWithHashAlg) -> Self { + PublicKeyOrCertificate::PublicKey { + key: key.public_key().clone(), + hash_alg: key.hash_alg(), + } + } +} + +impl PublicKeyOrCertificate { + #[cfg(not(target_arch = "wasm32"))] + pub fn decode(pubkey_algo: &str, buf: &[u8]) -> Result { + let mut reader = buf; + match Algorithm::new_certificate_ext(pubkey_algo) { + Ok(Algorithm::Other(_)) | Err(ssh_key::Error::Encoding(_)) => { + // Did not match a known cert algorithm + Ok(PublicKeyOrCertificate::PublicKey { + key: KeyData::decode(&mut reader)?.into(), + hash_alg: Algorithm::new(pubkey_algo)?.hash_alg(), + }) + } + _ => Ok(PublicKeyOrCertificate::Certificate(Certificate::decode( + &mut reader, + )?)), + } + } +} diff --git a/russh/src/channel_stream/mod.rs b/russh/src/channel_stream/mod.rs deleted file mode 100644 index 80ee89a6..00000000 --- a/russh/src/channel_stream/mod.rs +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// Originally from microsoft/dev-tunnels - -mod read_buffer; - -use std::io; -use std::pin::Pin; -use std::task::Poll; - -use log::debug; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc; - -use self::read_buffer::ReadBuffer; - -/// AsyncRead/AsyncWrite wrapper for SSH Channels -pub struct ChannelStream { - incoming: mpsc::UnboundedReceiver>, - outgoing: mpsc::UnboundedSender>, - - readbuf: ReadBuffer, - - is_write_fut_valid: bool, - write_fut: tokio_util::sync::ReusableBoxFuture<'static, Result<(), Vec>>, -} - -impl ChannelStream { - pub fn new() -> ( - Self, - mpsc::UnboundedReceiver>, - mpsc::UnboundedSender>, - ) { - let (w_tx, w_rx) = mpsc::unbounded_channel(); - let (r_tx, r_rx) = mpsc::unbounded_channel(); - ( - ChannelStream { - incoming: w_rx, - outgoing: r_tx, - readbuf: ReadBuffer::default(), - is_write_fut_valid: false, - write_fut: tokio_util::sync::ReusableBoxFuture::new(make_client_write_fut(None)), - }, - r_rx, - w_tx, - ) - } -} - -/// Makes a future that writes to the russh handle. This general approach was -/// taken from https://docs.rs/tokio-util/0.7.3/tokio_util/sync/struct.PollSender.html -/// This is just like make_server_write_fut, but for clients (they don't share a trait...) -async fn make_client_write_fut( - data: Option<(mpsc::UnboundedSender>, Vec)>, -) -> Result<(), Vec> { - match data { - Some((sender, data)) => sender.send(data).map_err(|e| e.0), - None => unreachable!("this future should not be pollable in this state"), - } -} - -impl AsyncWrite for ChannelStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - if !self.is_write_fut_valid { - let outgoing = self.outgoing.clone(); - self.write_fut - .set(make_client_write_fut(Some((outgoing, buf.to_vec())))); - self.is_write_fut_valid = true; - } - - self.poll_flush(cx).map(|r| r.map(|_| buf.len())) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if !self.is_write_fut_valid { - return Poll::Ready(Ok(())); - } - - match self.write_fut.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(_)) => { - self.is_write_fut_valid = false; - Poll::Ready(Ok(())) - } - Poll::Ready(Err(_)) => { - self.is_write_fut_valid = false; - debug!("ChannelStream AsyncWrite EOF"); - Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "EOF"))) - } - } - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - if let Err(err) = self.outgoing.send("".into()) { - let err = format!("{err:?}"); - return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err))) - } - Poll::Ready(Ok(())) - } -} - -impl AsyncRead for ChannelStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - if let Some((v, s)) = self.readbuf.take_data() { - return self.readbuf.put_data(buf, v, s); - } - - let x = self.incoming.poll_recv(cx); - match x { - Poll::Ready(Some(msg)) => self.readbuf.put_data(buf, msg, 0), - Poll::Ready(None) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, - } - } -} diff --git a/russh/src/channel_stream/read_buffer.rs b/russh/src/channel_stream/read_buffer.rs deleted file mode 100644 index 2521e068..00000000 --- a/russh/src/channel_stream/read_buffer.rs +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// Originally from microsoft/dev-tunnels - -use std::task::Poll; - -/// Helper used when converting Future interfaces to poll-based interfaces. -/// Stores excess data that can be reused on future polls. -#[derive(Default)] -pub(crate) struct ReadBuffer(Option<(Vec, usize)>); - -impl ReadBuffer { - /// Removes any data stored in the read buffer - pub fn take_data(&mut self) -> Option<(Vec, usize)> { - self.0.take() - } - - /// Writes as many bytes as possible to the readbuf, stashing any extra. - pub fn put_data( - &mut self, - target: &mut tokio::io::ReadBuf<'_>, - bytes: Vec, - start: usize, - ) -> Poll> { - if target.remaining() >= bytes.len() - start { - if start < bytes.len() { - #[allow(clippy::indexing_slicing)] - target.put_slice(&bytes[start..]); - } - self.0 = None; - } else { - let end = start + target.remaining(); - if start < bytes.len() && end <= bytes.len() { - #[allow(clippy::indexing_slicing)] - target.put_slice(&bytes[start..end]); - } - self.0 = Some((bytes, end)); - } - - Poll::Ready(Ok(())) - } -} diff --git a/russh/src/channels.rs b/russh/src/channels.rs deleted file mode 100644 index e5da311e..00000000 --- a/russh/src/channels.rs +++ /dev/null @@ -1,413 +0,0 @@ -use russh_cryptovec::CryptoVec; -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; -use log::debug; - -use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig}; - -#[derive(Debug)] -#[non_exhaustive] -/// Possible messages that [Channel::wait] can receive. -pub enum ChannelMsg { - Open { - id: ChannelId, - max_packet_size: u32, - window_size: u32, - }, - Data { - data: CryptoVec, - }, - ExtendedData { - data: CryptoVec, - ext: u32, - }, - Eof, - /// (client only) - RequestPty { - want_reply: bool, - term: String, - col_width: u32, - row_height: u32, - pix_width: u32, - pix_height: u32, - terminal_modes: Vec<(Pty, u32)>, - }, - /// (client only) - RequestShell { - want_reply: bool, - }, - /// (client only) - Exec { - want_reply: bool, - command: Vec, - }, - /// (client only) - Signal { - signal: Sig, - }, - /// (client only) - RequestSubsystem { - want_reply: bool, - name: String, - }, - /// (client only) - RequestX11 { - want_reply: bool, - single_connection: bool, - x11_authentication_protocol: String, - x11_authentication_cookie: String, - x11_screen_number: u32, - }, - /// (client only) - SetEnv { - want_reply: bool, - variable_name: String, - variable_value: String, - }, - /// (client only) - WindowChange { - col_width: u32, - row_height: u32, - pix_width: u32, - pix_height: u32, - }, - /// (client only) - AgentForward { - want_reply: bool, - }, - - /// (server only) - XonXoff { - client_can_do: bool, - }, - /// (server only) - ExitStatus { - exit_status: u32, - }, - /// (server only) - ExitSignal { - signal_name: Sig, - core_dumped: bool, - error_message: String, - lang_tag: String, - }, - /// (server only) - WindowAdjusted { - new_size: u32, - }, - /// (server only) - Success, - /// (server only) - Failure, - /// (server only) - Close, - OpenFailure(ChannelOpenFailure), -} - -/// A handle to a session channel. -/// -/// Allows you to read and write from a channel without borrowing the session -pub struct Channel> { - pub(crate) id: ChannelId, - pub(crate) sender: Sender, - pub(crate) receiver: UnboundedReceiver, - pub(crate) max_packet_size: u32, - pub(crate) window_size: u32, -} - -impl> std::fmt::Debug for Channel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Channel").field("id", &self.id).finish() - } -} - -impl + Send + 'static> Channel { - pub fn id(&self) -> ChannelId { - self.id - } - - /// Returns the min between the maximum packet size and the - /// remaining window size in the channel. - pub fn writable_packet_size(&self) -> usize { - self.max_packet_size.min(self.window_size) as usize - } - - /// Request a pseudo-terminal with the given characteristics. - #[allow(clippy::too_many_arguments)] // length checked - pub async fn request_pty( - &mut self, - want_reply: bool, - term: &str, - col_width: u32, - row_height: u32, - pix_width: u32, - pix_height: u32, - terminal_modes: &[(Pty, u32)], - ) -> Result<(), Error> { - self.send_msg(ChannelMsg::RequestPty { - want_reply, - term: term.to_string(), - col_width, - row_height, - pix_width, - pix_height, - terminal_modes: terminal_modes.to_vec(), - }) - .await?; - Ok(()) - } - - /// Request a remote shell. - pub async fn request_shell(&mut self, want_reply: bool) -> Result<(), Error> { - self.send_msg(ChannelMsg::RequestShell { want_reply }) - .await?; - Ok(()) - } - - /// Execute a remote program (will be passed to a shell). This can - /// be used to implement scp (by calling a remote scp and - /// tunneling to its standard input). - pub async fn exec>>( - &mut self, - want_reply: bool, - command: A, - ) -> Result<(), Error> { - self.send_msg(ChannelMsg::Exec { - want_reply, - command: command.into(), - }) - .await?; - Ok(()) - } - - /// Signal a remote process. - pub async fn signal(&mut self, signal: Sig) -> Result<(), Error> { - self.send_msg(ChannelMsg::Signal { signal }).await?; - Ok(()) - } - - /// Request the start of a subsystem with the given name. - pub async fn request_subsystem>( - &mut self, - want_reply: bool, - name: A, - ) -> Result<(), Error> { - self.send_msg(ChannelMsg::RequestSubsystem { - want_reply, - name: name.into(), - }) - .await?; - Ok(()) - } - - /// Request X11 forwarding through an already opened X11 - /// channel. See - /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.3.1) - /// for security issues related to cookies. - pub async fn request_x11, B: Into>( - &mut self, - want_reply: bool, - single_connection: bool, - x11_authentication_protocol: A, - x11_authentication_cookie: B, - x11_screen_number: u32, - ) -> Result<(), Error> { - self.send_msg(ChannelMsg::RequestX11 { - want_reply, - single_connection, - x11_authentication_protocol: x11_authentication_protocol.into(), - x11_authentication_cookie: x11_authentication_cookie.into(), - x11_screen_number, - }) - .await?; - Ok(()) - } - - /// Set a remote environment variable. - pub async fn set_env, B: Into>( - &mut self, - want_reply: bool, - variable_name: A, - variable_value: B, - ) -> Result<(), Error> { - self.send_msg(ChannelMsg::SetEnv { - want_reply, - variable_name: variable_name.into(), - variable_value: variable_value.into(), - }) - .await?; - Ok(()) - } - - /// Inform the server that our window size has changed. - pub async fn window_change( - &mut self, - col_width: u32, - row_height: u32, - pix_width: u32, - pix_height: u32, - ) -> Result<(), Error> { - self.send_msg(ChannelMsg::WindowChange { - col_width, - row_height, - pix_width, - pix_height, - }) - .await?; - Ok(()) - } - - /// Inform the server that we will accept agent forwarding channels - pub async fn agent_forward(&mut self, want_reply: bool) -> Result<(), Error> { - self.send_msg(ChannelMsg::AgentForward { want_reply }) - .await?; - Ok(()) - } - - /// Send data to a channel. - pub async fn data(&mut self, data: R) -> Result<(), Error> { - self.send_data(None, data).await - } - - /// Send data to a channel. The number of bytes added to the - /// "sending pipeline" (to be processed by the event loop) is - /// returned. - pub async fn extended_data( - &mut self, - ext: u32, - data: R, - ) -> Result<(), Error> { - self.send_data(Some(ext), data).await - } - - async fn send_data( - &mut self, - ext: Option, - mut data: R, - ) -> Result<(), Error> { - let mut total = 0; - loop { - // wait for the window to be restored. - while self.window_size == 0 { - match self.receiver.recv().await { - Some(ChannelMsg::WindowAdjusted { new_size }) => { - debug!("window adjusted: {:?}", new_size); - self.window_size = new_size; - break; - } - Some(msg) => { - debug!("unexpected channel msg: {:?}", msg); - } - None => break, - } - } - debug!( - "sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}", - self.window_size, self.max_packet_size, total - ); - let sendable = self.window_size.min(self.max_packet_size) as usize; - - debug!("sendable {:?}", sendable); - - // If we can not send anymore, continue - // and wait for server window adjustment - if sendable == 0 { - continue; - } - - let mut c = CryptoVec::new_zeroed(sendable); - let n = data.read(&mut c[..]).await?; - total += n; - c.resize(n); - self.window_size -= n as u32; - self.send_data_packet(ext, c).await?; - if n == 0 { - break; - } else if self.window_size > 0 { - continue; - } - } - Ok(()) - } - - async fn send_data_packet(&mut self, ext: Option, data: CryptoVec) -> Result<(), Error> { - self.send_msg(if let Some(ext) = ext { - ChannelMsg::ExtendedData { ext, data } - } else { - ChannelMsg::Data { data } - }) - .await?; - Ok(()) - } - - pub async fn eof(&mut self) -> Result<(), Error> { - self.send_msg(ChannelMsg::Eof).await?; - Ok(()) - } - - /// Wait for data to come. - pub async fn wait(&mut self) -> Option { - match self.receiver.recv().await { - Some(ChannelMsg::WindowAdjusted { new_size }) => { - self.window_size = new_size; - Some(ChannelMsg::WindowAdjusted { new_size }) - } - Some(msg) => Some(msg), - None => None, - } - } - - async fn send_msg(&self, msg: ChannelMsg) -> Result<(), Error> { - self.sender - .send((self.id, msg).into()) - .await - .map_err(|_| Error::SendError) - } - - /// Request that the channel be closed. - pub async fn close(&self) -> Result<(), Error> { - self.send_msg(ChannelMsg::Close).await?; - Ok(()) - } - - pub fn into_stream(mut self) -> ChannelStream { - let (stream, mut r_rx, w_tx) = ChannelStream::new(); - - tokio::spawn(async move { - loop { - tokio::select! { - data = r_rx.recv() => { - match data { - Some(data) if !data.is_empty() => self.data(&data[..]).await?, - Some(_) => { - log::debug!("closing chan {:?}, received empty data", &self.id); - self.eof().await?; - self.close().await?; - break; - }, - None => { - self.close().await?; - break - } - } - }, - msg = self.wait() => { - match msg { - Some(ChannelMsg::Data { data }) => { - w_tx.send(data[..].into()).map_err(|_| crate::Error::SendError)?; - } - Some(ChannelMsg::Eof) => { - // Send a 0-length chunk to indicate EOF. - w_tx.send("".into()).map_err(|_| crate::Error::SendError)?; - break - } - None => break, - _ => (), - } - } - } - } - Ok::<_, crate::Error>(()) - }); - stream - } -} diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs new file mode 100644 index 00000000..d7f937cd --- /dev/null +++ b/russh/src/channels/channel_ref.rs @@ -0,0 +1,33 @@ +use tokio::sync::mpsc::Sender; + +use super::WindowSizeRef; +use crate::ChannelMsg; + +/// A handle to the [`super::Channel`]'s to be able to transmit messages +/// to it and update it's `window_size`. +#[derive(Debug)] +pub struct ChannelRef { + pub(super) sender: Sender, + pub(super) window_size: WindowSizeRef, +} + +impl ChannelRef { + pub fn new(sender: Sender) -> Self { + Self { + sender, + window_size: WindowSizeRef::new(0), + } + } + + pub(crate) fn window_size(&self) -> &WindowSizeRef { + &self.window_size + } +} + +impl std::ops::Deref for ChannelRef { + type Target = Sender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} diff --git a/russh/src/channels/channel_stream.rs b/russh/src/channels/channel_stream.rs new file mode 100644 index 00000000..9e8d14be --- /dev/null +++ b/russh/src/channels/channel_stream.rs @@ -0,0 +1,63 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::io::{ChannelCloseOnDrop, ChannelRx, ChannelTx}; +use super::{ChannelId, ChannelMsg}; + +/// AsyncRead/AsyncWrite wrapper for SSH Channels +pub struct ChannelStream +where + S: From<(ChannelId, ChannelMsg)> + Send + 'static, +{ + tx: ChannelTx, + rx: ChannelRx>, +} + +impl ChannelStream +where + S: From<(ChannelId, ChannelMsg)> + Send, +{ + pub(super) fn new(tx: ChannelTx, rx: ChannelRx>) -> Self { + Self { tx, rx } + } +} + +impl AsyncRead for ChannelStream +where + S: From<(ChannelId, ChannelMsg)> + Send, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.rx).poll_read(cx, buf) + } +} + +impl AsyncWrite for ChannelStream +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send + Sync, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.tx).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.tx).poll_shutdown(cx) + } +} diff --git a/russh/src/channels/io/mod.rs b/russh/src/channels/io/mod.rs new file mode 100644 index 00000000..95aeab50 --- /dev/null +++ b/russh/src/channels/io/mod.rs @@ -0,0 +1,44 @@ +mod rx; +use std::borrow::{Borrow, BorrowMut}; + +pub use rx::ChannelRx; + +mod tx; +pub use tx::ChannelTx; + +use crate::{Channel, ChannelId, ChannelMsg, ChannelReadHalf}; + +#[derive(Debug)] +pub struct ChannelCloseOnDrop + Send + 'static>(pub Channel); + +impl + Send + 'static> Borrow + for ChannelCloseOnDrop +{ + fn borrow(&self) -> &ChannelReadHalf { + &self.0.read_half + } +} + +impl + Send + 'static> BorrowMut + for ChannelCloseOnDrop +{ + fn borrow_mut(&mut self) -> &mut ChannelReadHalf { + &mut self.0.read_half + } +} + +impl + Send + 'static> Drop for ChannelCloseOnDrop { + fn drop(&mut self) { + let id = self.0.write_half.id; + let sender = self.0.write_half.sender.clone(); + + // Best effort: async drop where possible + #[cfg(not(target_arch = "wasm32"))] + tokio::spawn(async move { + let _ = sender.send((id, ChannelMsg::Close).into()).await; + }); + + #[cfg(target_arch = "wasm32")] + let _ = sender.try_send((id, ChannelMsg::Close).into()); + } +} diff --git a/russh/src/channels/io/rx.rs b/russh/src/channels/io/rx.rs new file mode 100644 index 00000000..57080db5 --- /dev/null +++ b/russh/src/channels/io/rx.rs @@ -0,0 +1,85 @@ +use std::borrow::BorrowMut; +use std::io; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use tokio::io::AsyncRead; + +use super::{ChannelMsg, ChannelReadHalf}; + +#[derive(Debug)] +pub struct ChannelRx { + channel: R, + buffer: Option<(ChannelMsg, usize)>, + + ext: Option, +} + +impl ChannelRx { + pub fn new(channel: R, ext: Option) -> Self { + Self { + channel, + buffer: None, + ext, + } + } +} + +impl AsyncRead for ChannelRx +where + R: BorrowMut + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let (msg, mut idx) = match self.buffer.take() { + Some(msg) => msg, + None => match ready!(self.channel.borrow_mut().receiver.poll_recv(cx)) { + Some(msg) => (msg, 0), + None => return Poll::Ready(Ok(())), + }, + }; + + match (&msg, self.ext) { + (ChannelMsg::Data { data }, None) => { + let readable = buf.remaining().min(data.len() - idx); + + // Clamped to maximum `buf.remaining()` and `data.len() - idx` with `.min` + #[allow(clippy::indexing_slicing)] + buf.put_slice(&data[idx..idx + readable]); + idx += readable; + + if idx != data.len() { + self.buffer = Some((msg, idx)); + } + + Poll::Ready(Ok(())) + } + (ChannelMsg::ExtendedData { data, ext }, Some(target)) if *ext == target => { + let readable = buf.remaining().min(data.len() - idx); + + // Clamped to maximum `buf.remaining()` and `data.len() - idx` with `.min` + #[allow(clippy::indexing_slicing)] + buf.put_slice(&data[idx..idx + readable]); + idx += readable; + + if idx != data.len() { + self.buffer = Some((msg, idx)); + } + + Poll::Ready(Ok(())) + } + (ChannelMsg::Eof, _) => { + self.channel.borrow_mut().receiver.close(); + + Poll::Ready(Ok(())) + } + _ => { + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } +} diff --git a/russh/src/channels/io/tx.rs b/russh/src/channels/io/tx.rs new file mode 100644 index 00000000..af9565b6 --- /dev/null +++ b/russh/src/channels/io/tx.rs @@ -0,0 +1,202 @@ +use std::convert::TryFrom; +use std::future::Future; +use std::io; +use std::num::NonZeroUsize; +use std::ops::DerefMut; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; + +use futures::FutureExt; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::{self, OwnedPermit}; +use tokio::sync::{Mutex, Notify, OwnedMutexGuard}; + +use super::ChannelMsg; +use crate::{ChannelId, CryptoVec}; + +type BoxedThreadsafeFuture = Pin>>; +type OwnedPermitFuture = + BoxedThreadsafeFuture, ChannelMsg, usize), SendError<()>>>; + +struct WatchNotification(Pin>>); + +/// A single future that becomes ready once the window size +/// changes to a positive value +impl WatchNotification { + fn new(n: Arc) -> Self { + Self(Box::pin(async move { n.notified().await })) + } +} + +impl Future for WatchNotification { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let inner = self.deref_mut().0.as_mut(); + ready!(inner.poll(cx)); + Poll::Ready(()) + } +} + +pub struct ChannelTx { + sender: mpsc::Sender, + send_fut: Option>, + id: ChannelId, + window_size_fut: Option>>, + window_size: Arc>, + notify: Arc, + window_size_notication: WatchNotification, + max_packet_size: u32, + ext: Option, +} + +impl ChannelTx +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send, +{ + pub fn new( + sender: mpsc::Sender, + id: ChannelId, + window_size: Arc>, + window_size_notification: Arc, + max_packet_size: u32, + ext: Option, + ) -> Self { + Self { + sender, + send_fut: None, + id, + notify: Arc::clone(&window_size_notification), + window_size_notication: WatchNotification::new(window_size_notification), + window_size, + window_size_fut: None, + max_packet_size, + ext, + } + } + + fn poll_writable(&mut self, cx: &mut Context<'_>, buf_len: usize) -> Poll { + let window_size = self.window_size.clone(); + let window_size_fut = self + .window_size_fut + .get_or_insert_with(|| Box::pin(window_size.lock_owned())); + let mut window_size = ready!(window_size_fut.poll_unpin(cx)); + self.window_size_fut.take(); + + let writable = (self.max_packet_size).min(*window_size).min(buf_len as u32) as usize; + + match NonZeroUsize::try_from(writable) { + Ok(w) => { + *window_size -= writable as u32; + if *window_size > 0 { + self.notify.notify_one(); + } + Poll::Ready(w) + } + Err(_) => { + drop(window_size); + ready!(self.window_size_notication.poll_unpin(cx)); + self.window_size_notication = WatchNotification::new(Arc::clone(&self.notify)); + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + fn poll_mk_msg( + &mut self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<(ChannelMsg, NonZeroUsize)> { + let writable = ready!(self.poll_writable(cx, buf.len())); + + let mut data = CryptoVec::new_zeroed(writable.into()); + #[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.poll_writable` + data.copy_from_slice(&buf[..writable.into()]); + data.resize(writable.into()); + + let msg = match self.ext { + None => ChannelMsg::Data { data }, + Some(ext) => ChannelMsg::ExtendedData { data, ext }, + }; + + Poll::Ready((msg, writable)) + } + + fn activate(&mut self, msg: ChannelMsg, writable: usize) -> &mut OwnedPermitFuture { + use futures::TryFutureExt; + self.send_fut.insert(Box::pin( + self.sender + .clone() + .reserve_owned() + .map_ok(move |p| (p, msg, writable)), + )) + } + + fn handle_write_result( + &mut self, + r: Result<(OwnedPermit, ChannelMsg, usize), SendError<()>>, + ) -> Result { + self.send_fut = None; + match r { + Ok((permit, msg, writable)) => { + permit.send((self.id, msg).into()); + Ok(writable) + } + Err(SendError(())) => Err(io::Error::new(io::ErrorKind::BrokenPipe, "channel closed")), + } + } +} + +impl AsyncWrite for ChannelTx +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send, +{ + #[allow(clippy::too_many_lines)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "cannot send empty buffer", + ))); + } + let send_fut = if let Some(x) = self.send_fut.as_mut() { + x + } else { + let (msg, writable) = ready!(self.poll_mk_msg(cx, buf)); + self.activate(msg, writable.into()) + }; + let r = ready!(send_fut.as_mut().poll_unpin(cx)); + Poll::Ready(self.handle_write_result(r)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let send_fut = if let Some(x) = self.send_fut.as_mut() { + x + } else { + self.activate(ChannelMsg::Eof, 0) + }; + let r = ready!(send_fut.as_mut().poll_unpin(cx)).map(|(p, _, _)| (p, ChannelMsg::Eof, 0)); + Poll::Ready(self.handle_write_result(r).map(drop)) + } +} + +impl Drop for ChannelTx { + fn drop(&mut self) { + // Allow other writers to make progress + self.notify.notify_one(); + } +} diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs new file mode 100644 index 00000000..bf0f406e --- /dev/null +++ b/russh/src/channels/mod.rs @@ -0,0 +1,648 @@ +use std::{pin::Pin, sync::Arc}; + +use futures::{Future, FutureExt as _}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::{Mutex, Notify}; + +use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig}; + +pub mod io; + +mod channel_ref; +pub use channel_ref::ChannelRef; + +mod channel_stream; +pub use channel_stream::ChannelStream; + +#[derive(Debug)] +#[non_exhaustive] +/// Possible messages that [Channel::wait] can receive. +pub enum ChannelMsg { + Open { + id: ChannelId, + max_packet_size: u32, + window_size: u32, + }, + Data { + data: CryptoVec, + }, + ExtendedData { + data: CryptoVec, + ext: u32, + }, + Eof, + Close, + /// (client only) + RequestPty { + want_reply: bool, + term: String, + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + terminal_modes: Vec<(Pty, u32)>, + }, + /// (client only) + RequestShell { + want_reply: bool, + }, + /// (client only) + Exec { + want_reply: bool, + command: Vec, + }, + /// (client only) + Signal { + signal: Sig, + }, + /// (client only) + RequestSubsystem { + want_reply: bool, + name: String, + }, + /// (client only) + RequestX11 { + want_reply: bool, + single_connection: bool, + x11_authentication_protocol: String, + x11_authentication_cookie: String, + x11_screen_number: u32, + }, + /// (client only) + SetEnv { + want_reply: bool, + variable_name: String, + variable_value: String, + }, + /// (client only) + WindowChange { + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + }, + /// (client only) + AgentForward { + want_reply: bool, + }, + + /// (server only) + XonXoff { + client_can_do: bool, + }, + /// (server only) + ExitStatus { + exit_status: u32, + }, + /// (server only) + ExitSignal { + signal_name: Sig, + core_dumped: bool, + error_message: String, + lang_tag: String, + }, + /// (server only) + WindowAdjusted { + new_size: u32, + }, + /// (server only) + Success, + /// (server only) + Failure, + OpenFailure(ChannelOpenFailure), +} + +#[derive(Clone, Debug)] +pub(crate) struct WindowSizeRef { + value: Arc>, + notifier: Arc, +} + +impl WindowSizeRef { + pub(crate) fn new(initial: u32) -> Self { + let notifier = Arc::new(Notify::new()); + Self { + value: Arc::new(Mutex::new(initial)), + notifier, + } + } + + pub(crate) async fn update(&self, value: u32) { + *self.value.lock().await = value; + self.notifier.notify_one(); + } + + pub(crate) fn subscribe(&self) -> Arc { + Arc::clone(&self.notifier) + } +} + +/// A handle to the reading part of a session channel. +/// +/// Allows you to read from a channel without borrowing the session +pub struct ChannelReadHalf { + pub(crate) receiver: Receiver, +} + +impl std::fmt::Debug for ChannelReadHalf { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ChannelReadHalf").finish() + } +} + +impl ChannelReadHalf { + /// Awaits an incoming [`ChannelMsg`], this method returns [`None`] if the channel has been closed. + pub async fn wait(&mut self) -> Option { + self.receiver.recv().await + } + + /// Make a reader for the [`Channel`] to receive [`ChannelMsg::Data`] + /// through the `AsyncRead` trait. + pub fn make_reader(&mut self) -> impl AsyncRead + '_ { + self.make_reader_ext(None) + } + + /// Make a reader for the [`Channel`] to receive [`ChannelMsg::Data`] or [`ChannelMsg::ExtendedData`] + /// depending on the `ext` parameter, through the `AsyncRead` trait. + pub fn make_reader_ext(&mut self, ext: Option) -> impl AsyncRead + '_ { + io::ChannelRx::new(self, ext) + } +} + +/// A handle to the writing part of a session channel. +/// +/// Allows you to write to a channel without borrowing the session +pub struct ChannelWriteHalf> { + pub(crate) id: ChannelId, + pub(crate) sender: Sender, + pub(crate) max_packet_size: u32, + pub(crate) window_size: WindowSizeRef, +} + +impl> std::fmt::Debug for ChannelWriteHalf { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ChannelWriteHalf") + .field("id", &self.id) + .finish() + } +} + +impl + Send + Sync + 'static> ChannelWriteHalf { + /// Returns the min between the maximum packet size and the + /// remaining window size in the channel. + pub async fn writable_packet_size(&self) -> usize { + self.max_packet_size + .min(*self.window_size.value.lock().await) as usize + } + + pub fn id(&self) -> ChannelId { + self.id + } + + /// Request a pseudo-terminal with the given characteristics. + #[allow(clippy::too_many_arguments)] // length checked + pub async fn request_pty( + &self, + want_reply: bool, + term: &str, + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + terminal_modes: &[(Pty, u32)], + ) -> Result<(), Error> { + self.send_msg(ChannelMsg::RequestPty { + want_reply, + term: term.to_string(), + col_width, + row_height, + pix_width, + pix_height, + terminal_modes: terminal_modes.to_vec(), + }) + .await + } + + /// Request a remote shell. + pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> { + self.send_msg(ChannelMsg::RequestShell { want_reply }).await + } + + /// Execute a remote program (will be passed to a shell). This can + /// be used to implement scp (by calling a remote scp and + /// tunneling to its standard input). + pub async fn exec>>(&self, want_reply: bool, command: A) -> Result<(), Error> { + self.send_msg(ChannelMsg::Exec { + want_reply, + command: command.into(), + }) + .await + } + + /// Signal a remote process. + pub async fn signal(&self, signal: Sig) -> Result<(), Error> { + self.send_msg(ChannelMsg::Signal { signal }).await + } + + /// Request the start of a subsystem with the given name. + pub async fn request_subsystem>( + &self, + want_reply: bool, + name: A, + ) -> Result<(), Error> { + self.send_msg(ChannelMsg::RequestSubsystem { + want_reply, + name: name.into(), + }) + .await + } + + /// Request X11 forwarding through an already opened X11 + /// channel. See + /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.3.1) + /// for security issues related to cookies. + pub async fn request_x11, B: Into>( + &self, + want_reply: bool, + single_connection: bool, + x11_authentication_protocol: A, + x11_authentication_cookie: B, + x11_screen_number: u32, + ) -> Result<(), Error> { + self.send_msg(ChannelMsg::RequestX11 { + want_reply, + single_connection, + x11_authentication_protocol: x11_authentication_protocol.into(), + x11_authentication_cookie: x11_authentication_cookie.into(), + x11_screen_number, + }) + .await + } + + /// Set a remote environment variable. + pub async fn set_env, B: Into>( + &self, + want_reply: bool, + variable_name: A, + variable_value: B, + ) -> Result<(), Error> { + self.send_msg(ChannelMsg::SetEnv { + want_reply, + variable_name: variable_name.into(), + variable_value: variable_value.into(), + }) + .await + } + + /// Inform the server that our window size has changed. + pub async fn window_change( + &self, + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + ) -> Result<(), Error> { + self.send_msg(ChannelMsg::WindowChange { + col_width, + row_height, + pix_width, + pix_height, + }) + .await + } + + /// Inform the server that we will accept agent forwarding channels + pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> { + self.send_msg(ChannelMsg::AgentForward { want_reply }).await + } + + /// Send data to a channel. + pub async fn data(&self, data: R) -> Result<(), Error> { + self.send_data(None, data).await + } + + /// Send data to a channel. The number of bytes added to the + /// "sending pipeline" (to be processed by the event loop) is + /// returned. + pub async fn extended_data( + &self, + ext: u32, + data: R, + ) -> Result<(), Error> { + self.send_data(Some(ext), data).await + } + + async fn send_data( + &self, + ext: Option, + mut data: R, + ) -> Result<(), Error> { + let mut tx = self.make_writer_ext(ext); + + tokio::io::copy(&mut data, &mut tx).await?; + + Ok(()) + } + + pub async fn eof(&self) -> Result<(), Error> { + self.send_msg(ChannelMsg::Eof).await + } + + pub async fn exit_status(&self, exit_status: u32) -> Result<(), Error> { + self.send_msg(ChannelMsg::ExitStatus { exit_status }).await + } + + /// Request that the channel be closed. + pub async fn close(&self) -> Result<(), Error> { + self.send_msg(ChannelMsg::Close).await + } + + async fn send_msg(&self, msg: ChannelMsg) -> Result<(), Error> { + self.sender + .send((self.id, msg).into()) + .await + .map_err(|_| Error::SendError) + } + + /// Make a writer for the [`Channel`] to send [`ChannelMsg::Data`] + /// through the `AsyncWrite` trait. + pub fn make_writer(&self) -> impl AsyncWrite { + self.make_writer_ext(None) + } + + /// Make a writer for the [`Channel`] to send [`ChannelMsg::Data`] or [`ChannelMsg::ExtendedData`] + /// depending on the `ext` parameter, through the `AsyncWrite` trait. + pub fn make_writer_ext(&self, ext: Option) -> impl AsyncWrite { + io::ChannelTx::new( + self.sender.clone(), + self.id, + self.window_size.value.clone(), + self.window_size.subscribe(), + self.max_packet_size, + ext, + ) + } +} + +/// A handle to a session channel. +/// +/// Allows you to read and write from a channel without borrowing the session +pub struct Channel> { + pub(crate) read_half: ChannelReadHalf, + pub(crate) write_half: ChannelWriteHalf, +} + +impl> std::fmt::Debug for Channel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Channel") + .field("id", &self.write_half.id) + .finish() + } +} + +impl + Send + Sync + 'static> Channel { + pub(crate) fn new( + id: ChannelId, + sender: Sender, + max_packet_size: u32, + window_size: u32, + channel_buffer_size: usize, + ) -> (Self, ChannelRef) { + let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size); + let window_size = WindowSizeRef::new(window_size); + let read_half = ChannelReadHalf { receiver: rx }; + let write_half = ChannelWriteHalf { + id, + sender, + max_packet_size, + window_size: window_size.clone(), + }; + + ( + Self { + write_half, + read_half, + }, + ChannelRef { + sender: tx, + window_size, + }, + ) + } + + /// Returns the min between the maximum packet size and the + /// remaining window size in the channel. + pub async fn writable_packet_size(&self) -> usize { + self.write_half.writable_packet_size().await + } + + pub fn id(&self) -> ChannelId { + self.write_half.id() + } + + /// Split this [`Channel`] into a [`ChannelReadHalf`] and a [`ChannelWriteHalf`], which can be + /// used to read and write concurrently. + pub fn split(self) -> (ChannelReadHalf, ChannelWriteHalf) { + (self.read_half, self.write_half) + } + + /// Request a pseudo-terminal with the given characteristics. + #[allow(clippy::too_many_arguments)] // length checked + pub async fn request_pty( + &self, + want_reply: bool, + term: &str, + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + terminal_modes: &[(Pty, u32)], + ) -> Result<(), Error> { + self.write_half + .request_pty( + want_reply, + term, + col_width, + row_height, + pix_width, + pix_height, + terminal_modes, + ) + .await + } + + /// Request a remote shell. + pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> { + self.write_half.request_shell(want_reply).await + } + + /// Execute a remote program (will be passed to a shell). This can + /// be used to implement scp (by calling a remote scp and + /// tunneling to its standard input). + pub async fn exec>>(&self, want_reply: bool, command: A) -> Result<(), Error> { + self.write_half.exec(want_reply, command).await + } + + /// Signal a remote process. + pub async fn signal(&self, signal: Sig) -> Result<(), Error> { + self.write_half.signal(signal).await + } + + /// Get a `FnOnce` that can be used to send a signal through this channel + pub fn get_signal_sender( + &self, + ) -> impl FnOnce(Sig) -> Pin> + std::marker::Send>> + { + let sender = self.write_half.sender.clone(); + let id = self.write_half.id; + + move |signal| { + async move { + sender + .send((id, ChannelMsg::Signal { signal }).into()) + .await + .map_err(|_| Error::SendError)?; + + Ok(()) + } + .boxed() + } + } + + /// Request the start of a subsystem with the given name. + pub async fn request_subsystem>( + &self, + want_reply: bool, + name: A, + ) -> Result<(), Error> { + self.write_half.request_subsystem(want_reply, name).await + } + + /// Request X11 forwarding through an already opened X11 + /// channel. See + /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.3.1) + /// for security issues related to cookies. + pub async fn request_x11, B: Into>( + &self, + want_reply: bool, + single_connection: bool, + x11_authentication_protocol: A, + x11_authentication_cookie: B, + x11_screen_number: u32, + ) -> Result<(), Error> { + self.write_half + .request_x11( + want_reply, + single_connection, + x11_authentication_protocol, + x11_authentication_cookie, + x11_screen_number, + ) + .await + } + + /// Set a remote environment variable. + pub async fn set_env, B: Into>( + &self, + want_reply: bool, + variable_name: A, + variable_value: B, + ) -> Result<(), Error> { + self.write_half + .set_env(want_reply, variable_name, variable_value) + .await + } + + /// Inform the server that our window size has changed. + pub async fn window_change( + &self, + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + ) -> Result<(), Error> { + self.write_half + .window_change(col_width, row_height, pix_width, pix_height) + .await + } + + /// Inform the server that we will accept agent forwarding channels + pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> { + self.write_half.agent_forward(want_reply).await + } + + /// Send data to a channel. + pub async fn data(&self, data: R) -> Result<(), Error> { + self.write_half.data(data).await + } + + /// Send data to a channel. The number of bytes added to the + /// "sending pipeline" (to be processed by the event loop) is + /// returned. + pub async fn extended_data( + &self, + ext: u32, + data: R, + ) -> Result<(), Error> { + self.write_half.extended_data(ext, data).await + } + + pub async fn eof(&self) -> Result<(), Error> { + self.write_half.eof().await + } + + pub async fn exit_status(&self, exit_status: u32) -> Result<(), Error> { + self.write_half.exit_status(exit_status).await + } + + /// Request that the channel be closed. + pub async fn close(&self) -> Result<(), Error> { + self.write_half.close().await + } + + /// Awaits an incoming [`ChannelMsg`], this method returns [`None`] if the channel has been closed. + pub async fn wait(&mut self) -> Option { + self.read_half.wait().await + } + + /// Consume the [`Channel`] to produce a bidirectionnal stream, + /// sending and receiving [`ChannelMsg::Data`] as `AsyncRead` + `AsyncWrite`. + pub fn into_stream(self) -> ChannelStream { + ChannelStream::new( + io::ChannelTx::new( + self.write_half.sender.clone(), + self.write_half.id, + self.write_half.window_size.value.clone(), + self.write_half.window_size.subscribe(), + self.write_half.max_packet_size, + None, + ), + io::ChannelRx::new(io::ChannelCloseOnDrop(self), None), + ) + } + + /// Make a reader for the [`Channel`] to receive [`ChannelMsg::Data`] + /// through the `AsyncRead` trait. + pub fn make_reader(&mut self) -> impl AsyncRead + '_ { + self.read_half.make_reader() + } + + /// Make a reader for the [`Channel`] to receive [`ChannelMsg::Data`] or [`ChannelMsg::ExtendedData`] + /// depending on the `ext` parameter, through the `AsyncRead` trait. + pub fn make_reader_ext(&mut self, ext: Option) -> impl AsyncRead + '_ { + self.read_half.make_reader_ext(ext) + } + + /// Make a writer for the [`Channel`] to send [`ChannelMsg::Data`] + /// through the `AsyncWrite` trait. + pub fn make_writer(&self) -> impl AsyncWrite { + self.write_half.make_writer() + } + + /// Make a writer for the [`Channel`] to send [`ChannelMsg::Data`] or [`ChannelMsg::ExtendedData`] + /// depending on the `ext` parameter, through the `AsyncWrite` trait. + pub fn make_writer_ext(&self, ext: Option) -> impl AsyncWrite { + self.write_half.make_writer_ext(ext) + } +} diff --git a/russh/src/cipher/benchmark.rs b/russh/src/cipher/benchmark.rs new file mode 100644 index 00000000..115b9a60 --- /dev/null +++ b/russh/src/cipher/benchmark.rs @@ -0,0 +1,47 @@ +#![allow(clippy::unwrap_used)] +use criterion::*; +use rand::RngCore; + +pub fn bench(c: &mut Criterion) { + let mut rand_generator = black_box(rand::rngs::OsRng {}); + + let mut packet_length = black_box(vec![0u8; 4]); + + for cipher_name in [super::CHACHA20_POLY1305, super::AES_256_GCM] { + let cipher = super::CIPHERS.get(&cipher_name).unwrap(); + + let mut key = vec![0; cipher.key_len()]; + rand_generator.try_fill_bytes(&mut key).unwrap(); + let mut nonce = vec![0; cipher.nonce_len()]; + rand_generator.try_fill_bytes(&mut nonce).unwrap(); + + let mut sk = cipher.make_sealing_key(&key, &nonce, &[], &crate::mac::_NONE); + let mut ok = cipher.make_opening_key(&key, &nonce, &[], &crate::mac::_NONE); + + let mut group = c.benchmark_group(format!("Cipher: {}", cipher_name.0)); + for size in [100usize, 1000, 10000] { + let iterations = 10000 / size; + + group.throughput(Throughput::Bytes(size as u64)); + group.bench_function(format!("Block size: {size}"), |b| { + b.iter_with_setup( + || { + let mut in_out = black_box(vec![0u8; size]); + rand_generator.try_fill_bytes(&mut in_out).unwrap(); + rand_generator.try_fill_bytes(&mut packet_length).unwrap(); + in_out + }, + |mut in_out| { + for _ in 0..iterations { + let len = in_out.len(); + let (data, tag) = in_out.split_at_mut(len - sk.tag_len()); + sk.seal(0, data, tag); + ok.open(0, &mut in_out).unwrap(); + } + }, + ); + }); + } + group.finish(); + } +} diff --git a/russh/src/cipher/block.rs b/russh/src/cipher/block.rs index ccd6a4de..f30c36c8 100644 --- a/russh/src/cipher/block.rs +++ b/russh/src/cipher/block.rs @@ -11,6 +11,7 @@ // limitations under the License. // +use std::convert::TryInto; use std::marker::PhantomData; use aes::cipher::{IvSizeUser, KeyIvInit, KeySizeUser, StreamCipher}; @@ -21,9 +22,9 @@ use super::super::Error; use super::PACKET_LENGTH_LEN; use crate::mac::{Mac, MacAlgorithm}; -pub struct SshBlockCipher(pub PhantomData); +pub struct SshBlockCipher(pub PhantomData); -impl super::Cipher +impl super::Cipher for SshBlockCipher { fn key_len(&self) -> usize { @@ -73,29 +74,44 @@ impl su } } -pub struct OpeningKey { - cipher: C, - mac: Box, +pub struct OpeningKey { + pub(crate) cipher: C, + pub(crate) mac: Box, } -pub struct SealingKey { - cipher: C, - mac: Box, +pub struct SealingKey { + pub(crate) cipher: C, + pub(crate) mac: Box, } -impl super::OpeningKey for OpeningKey { +impl super::OpeningKey for OpeningKey { + fn packet_length_to_read_for_block_length(&self) -> usize { + 16 + } + fn decrypt_packet_length( &self, _sequence_number: u32, - mut encrypted_packet_length: [u8; 4], + encrypted_packet_length: &[u8], ) -> [u8; 4] { + let mut first_block = [0u8; 16]; + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::indexing_slicing)] + first_block.copy_from_slice(&encrypted_packet_length[..16]); + if self.mac.is_etm() { - encrypted_packet_length + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + encrypted_packet_length[..4].try_into().unwrap() } else { // Work around uncloneable Aes<> let mut cipher: C = unsafe { std::ptr::read(&self.cipher as *const C) }; - cipher.apply_keystream(&mut encrypted_packet_length); - encrypted_packet_length + + cipher.decrypt_data(&mut first_block); + + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + first_block[..4].try_into().unwrap() } } @@ -106,9 +122,10 @@ impl super::OpeningKey for OpeningKe fn open<'a>( &mut self, sequence_number: u32, - ciphertext_in_plaintext_out: &'a mut [u8], - tag: &[u8], + ciphertext_and_tag: &'a mut [u8], ) -> Result<&'a [u8], Error> { + let ciphertext_len = ciphertext_and_tag.len() - self.tag_len(); + let (ciphertext_in_plaintext_out, tag) = ciphertext_and_tag.split_at_mut(ciphertext_len); if self.mac.is_etm() { if !self .mac @@ -118,9 +135,9 @@ impl super::OpeningKey for OpeningKe } #[allow(clippy::indexing_slicing)] self.cipher - .apply_keystream(&mut ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]); + .decrypt_data(&mut ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]); } else { - self.cipher.apply_keystream(ciphertext_in_plaintext_out); + self.cipher.decrypt_data(ciphertext_in_plaintext_out); if !self .mac @@ -129,11 +146,13 @@ impl super::OpeningKey for OpeningKe return Err(Error::PacketAuth); } } - Ok(ciphertext_in_plaintext_out) + + #[allow(clippy::indexing_slicing)] + Ok(&ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]) } } -impl super::SealingKey for SealingKey { +impl super::SealingKey for SealingKey { fn padding_length(&self, payload: &[u8]) -> usize { let block_size = 16; @@ -174,13 +193,28 @@ impl super::SealingKey for SealingKe if self.mac.is_etm() { #[allow(clippy::indexing_slicing)] self.cipher - .apply_keystream(&mut plaintext_in_ciphertext_out[PACKET_LENGTH_LEN..]); + .encrypt_data(&mut plaintext_in_ciphertext_out[PACKET_LENGTH_LEN..]); self.mac .compute(sequence_number, plaintext_in_ciphertext_out, tag_out); } else { self.mac .compute(sequence_number, plaintext_in_ciphertext_out, tag_out); - self.cipher.apply_keystream(plaintext_in_ciphertext_out); + self.cipher.encrypt_data(plaintext_in_ciphertext_out); } } } + +pub trait BlockStreamCipher { + fn encrypt_data(&mut self, data: &mut [u8]); + fn decrypt_data(&mut self, data: &mut [u8]); +} + +impl BlockStreamCipher for T { + fn encrypt_data(&mut self, data: &mut [u8]) { + self.apply_keystream(data); + } + + fn decrypt_data(&mut self, data: &mut [u8]) { + self.apply_keystream(data); + } +} diff --git a/russh/src/cipher/cbc.rs b/russh/src/cipher/cbc.rs new file mode 100644 index 00000000..87a0c66a --- /dev/null +++ b/russh/src/cipher/cbc.rs @@ -0,0 +1,53 @@ +use aes::cipher::{ + BlockCipher, BlockDecrypt, BlockDecryptMut, BlockEncrypt, BlockEncryptMut, InnerIvInit, Iv, + IvSizeUser, +}; +use cbc::{Decryptor, Encryptor}; +use digest::crypto_common::InnerUser; +use generic_array::GenericArray; + +use super::block::BlockStreamCipher; + +pub struct CbcWrapper { + encryptor: Encryptor, + decryptor: Decryptor, +} + +impl InnerUser for CbcWrapper { + type Inner = C; +} + +impl IvSizeUser for CbcWrapper { + type IvSize = C::BlockSize; +} + +impl BlockStreamCipher for CbcWrapper { + fn encrypt_data(&mut self, data: &mut [u8]) { + for chunk in data.chunks_exact_mut(C::block_size()) { + let mut block: GenericArray = GenericArray::clone_from_slice(chunk); + self.encryptor.encrypt_block_mut(&mut block); + chunk.clone_from_slice(&block); + } + } + + fn decrypt_data(&mut self, data: &mut [u8]) { + for chunk in data.chunks_exact_mut(C::block_size()) { + let mut block = GenericArray::clone_from_slice(chunk); + self.decryptor.decrypt_block_mut(&mut block); + chunk.clone_from_slice(&block); + } + } +} + +impl InnerIvInit for CbcWrapper +where + C: BlockEncryptMut + BlockCipher, +{ + #[inline] + fn inner_iv_init(cipher: C, iv: &Iv) -> Self { + Self { + encryptor: Encryptor::inner_iv_init(cipher.clone(), iv), + decryptor: Decryptor::inner_iv_init(cipher, iv), + } + } +} diff --git a/russh/src/cipher/chacha20poly1305.rs b/russh/src/cipher/chacha20poly1305.rs index cab3eece..8e288b73 100644 --- a/russh/src/cipher/chacha20poly1305.rs +++ b/russh/src/cipher/chacha20poly1305.rs @@ -15,33 +15,21 @@ // http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?annotate=HEAD -use aes::cipher::{BlockSizeUser, StreamCipherSeek}; -use byteorder::{BigEndian, ByteOrder}; -use chacha20::cipher::{KeyInit, KeyIvInit, StreamCipher}; -use chacha20::{ChaCha20Legacy, ChaCha20LegacyCore}; -use generic_array::typenum::{Unsigned, U16, U32, U8}; -use generic_array::GenericArray; -use poly1305::Poly1305; -use subtle::ConstantTimeEq; +#[cfg(feature = "aws-lc-rs")] +use aws_lc_rs::aead::chacha20_poly1305_openssh; +#[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] +use ring::aead::chacha20_poly1305_openssh; use super::super::Error; -use crate::cipher::PACKET_LENGTH_LEN; use crate::mac::MacAlgorithm; pub struct SshChacha20Poly1305Cipher {} -type KeyLength = U32; -type NonceLength = U8; -type TagLength = U16; -type Key = GenericArray; -type Nonce = GenericArray; - impl super::Cipher for SshChacha20Poly1305Cipher { fn key_len(&self) -> usize { - KeyLength::to_usize() * 2 + chacha20_poly1305_openssh::KEY_LEN } - #[allow(clippy::indexing_slicing)] // length checked fn make_opening_key( &self, k: &[u8], @@ -49,14 +37,12 @@ impl super::Cipher for SshChacha20Poly1305Cipher { _: &[u8], _: &dyn MacAlgorithm, ) -> Box { - let mut k1 = Key::default(); - let mut k2 = Key::default(); - k1.clone_from_slice(&k[KeyLength::to_usize()..]); - k2.clone_from_slice(&k[..KeyLength::to_usize()]); - Box::new(OpeningKey { k1, k2 }) + Box::new(OpeningKey(chacha20_poly1305_openssh::OpeningKey::new( + #[allow(clippy::unwrap_used)] + k.try_into().unwrap(), + ))) } - #[allow(clippy::indexing_slicing)] // length checked fn make_sealing_key( &self, k: &[u8], @@ -64,68 +50,50 @@ impl super::Cipher for SshChacha20Poly1305Cipher { _: &[u8], _: &dyn MacAlgorithm, ) -> Box { - let mut k1 = Key::default(); - let mut k2 = Key::default(); - k1.clone_from_slice(&k[KeyLength::to_usize()..]); - k2.clone_from_slice(&k[..KeyLength::to_usize()]); - Box::new(SealingKey { k1, k2 }) + Box::new(SealingKey(chacha20_poly1305_openssh::SealingKey::new( + #[allow(clippy::unwrap_used)] + k.try_into().unwrap(), + ))) } } -pub struct OpeningKey { - k1: Key, - k2: Key, -} +pub struct OpeningKey(chacha20_poly1305_openssh::OpeningKey); -pub struct SealingKey { - k1: Key, - k2: Key, -} - -#[allow(clippy::indexing_slicing)] // length checked -fn make_counter(sequence_number: u32) -> Nonce { - let mut nonce = Nonce::default(); - let i0 = NonceLength::to_usize() - 4; - BigEndian::write_u32(&mut nonce[i0..], sequence_number); - nonce -} +pub struct SealingKey(chacha20_poly1305_openssh::SealingKey); impl super::OpeningKey for OpeningKey { fn decrypt_packet_length( &self, sequence_number: u32, - mut encrypted_packet_length: [u8; 4], + encrypted_packet_length: &[u8], ) -> [u8; 4] { - let nonce = make_counter(sequence_number); - let mut cipher = ChaCha20Legacy::new(&self.k1, &nonce); - cipher.apply_keystream(&mut encrypted_packet_length); - encrypted_packet_length + self.0.decrypt_packet_length( + sequence_number, + #[allow(clippy::unwrap_used)] + encrypted_packet_length.try_into().unwrap(), + ) } fn tag_len(&self) -> usize { - TagLength::to_usize() + chacha20_poly1305_openssh::TAG_LEN } - #[allow(clippy::indexing_slicing)] // lengths checked fn open<'a>( &mut self, sequence_number: u32, - ciphertext_in_plaintext_out: &'a mut [u8], - tag: &[u8], + ciphertext_and_tag: &'a mut [u8], ) -> Result<&'a [u8], Error> { - let nonce = make_counter(sequence_number); - let expected_tag = compute_poly1305(&nonce, &self.k2, ciphertext_in_plaintext_out); - - if !bool::from(expected_tag.ct_eq(tag)) { - return Err(Error::DecryptionError); - } - - let mut cipher = ChaCha20Legacy::new(&self.k2, &nonce); - - cipher.seek(::BlockSize::to_usize()); - cipher.apply_keystream(&mut ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]); - - Ok(&ciphertext_in_plaintext_out[PACKET_LENGTH_LEN..]) + let ciphertext_len = ciphertext_and_tag.len() - self.tag_len(); + let (ciphertext_in_plaintext_out, tag) = ciphertext_and_tag.split_at_mut(ciphertext_len); + + self.0 + .open_in_place( + sequence_number, + ciphertext_in_plaintext_out, + #[allow(clippy::unwrap_used)] + &tag.try_into().unwrap(), + ) + .map_err(|_| Error::DecryptionError) } } @@ -156,7 +124,7 @@ impl super::SealingKey for SealingKey { } fn tag_len(&self) -> usize { - TagLength::to_usize() + chacha20_poly1305_openssh::TAG_LEN } fn seal( @@ -165,31 +133,11 @@ impl super::SealingKey for SealingKey { plaintext_in_ciphertext_out: &mut [u8], tag: &mut [u8], ) { - let nonce = make_counter(sequence_number); - - let mut cipher = ChaCha20Legacy::new(&self.k1, &nonce); - #[allow(clippy::indexing_slicing)] // length checked - cipher.apply_keystream(&mut plaintext_in_ciphertext_out[..PACKET_LENGTH_LEN]); - - // -- - let mut cipher = ChaCha20Legacy::new(&self.k2, &nonce); - - cipher.seek(::BlockSize::to_usize()); - #[allow(clippy::indexing_slicing, clippy::unwrap_used)] - cipher.apply_keystream(&mut plaintext_in_ciphertext_out[PACKET_LENGTH_LEN..]); - - // -- - - tag.copy_from_slice( - compute_poly1305(&nonce, &self.k2, plaintext_in_ciphertext_out).as_slice(), + self.0.seal_in_place( + sequence_number, + plaintext_in_ciphertext_out, + #[allow(clippy::unwrap_used)] + tag.try_into().unwrap(), ); } } - -fn compute_poly1305(nonce: &Nonce, key: &Key, data: &[u8]) -> poly1305::Tag { - let mut cipher = ChaCha20Legacy::new(key, nonce); - let mut poly_key = GenericArray::::default(); - cipher.apply_keystream(&mut poly_key); - - Poly1305::new(&poly_key).compute_unpadded(data) -} diff --git a/russh/src/cipher/clear.rs b/russh/src/cipher/clear.rs index ddd552db..955a4e80 100644 --- a/russh/src/cipher/clear.rs +++ b/russh/src/cipher/clear.rs @@ -13,6 +13,8 @@ // limitations under the License. // +use std::convert::TryInto; + use crate::mac::MacAlgorithm; use crate::Error; @@ -48,8 +50,10 @@ impl super::Cipher for Clear { } impl super::OpeningKey for Key { - fn decrypt_packet_length(&self, _seqn: u32, packet_length: [u8; 4]) -> [u8; 4] { - packet_length + fn decrypt_packet_length(&self, _seqn: u32, packet_length: &[u8]) -> [u8; 4] { + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + packet_length.try_into().unwrap() } fn tag_len(&self) -> usize { @@ -59,12 +63,10 @@ impl super::OpeningKey for Key { fn open<'a>( &mut self, _seqn: u32, - ciphertext_in_plaintext_out: &'a mut [u8], - tag: &[u8], + ciphertext_and_tag: &'a mut [u8], ) -> Result<&'a [u8], Error> { - debug_assert_eq!(tag.len(), 0); // self.tag_len()); #[allow(clippy::indexing_slicing)] // length known - Ok(&ciphertext_in_plaintext_out[4..]) + Ok(&ciphertext_and_tag[4..]) } } diff --git a/russh/src/cipher/gcm.rs b/russh/src/cipher/gcm.rs index f737716c..9855133c 100644 --- a/russh/src/cipher/gcm.rs +++ b/russh/src/cipher/gcm.rs @@ -15,28 +15,38 @@ // http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?annotate=HEAD -use aes_gcm::{AeadCore, AeadInPlace, Aes256Gcm, KeyInit, KeySizeUser}; -use byteorder::{BigEndian, ByteOrder}; -use digest::typenum::Unsigned; -use generic_array::GenericArray; +use std::convert::TryInto; + +#[cfg(feature = "aws-lc-rs")] +use aws_lc_rs::{ + aead::{ + Aad, Algorithm, BoundKey, Nonce as AeadNonce, NonceSequence, OpeningKey as AeadOpeningKey, + SealingKey as AeadSealingKey, UnboundKey, NONCE_LEN, + }, + error::Unspecified, +}; use rand::RngCore; +#[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] +use ring::{ + aead::{ + Aad, Algorithm, BoundKey, Nonce as AeadNonce, NonceSequence, OpeningKey as AeadOpeningKey, + SealingKey as AeadSealingKey, UnboundKey, NONCE_LEN, + }, + error::Unspecified, +}; use super::super::Error; use crate::mac::MacAlgorithm; -pub struct GcmCipher {} - -type KeySize = ::KeySize; -type NonceSize = ::NonceSize; -type TagSize = ::TagSize; +pub struct GcmCipher(pub(crate) &'static Algorithm); impl super::Cipher for GcmCipher { fn key_len(&self) -> usize { - Aes256Gcm::key_size() + self.0.key_len() } fn nonce_len(&self) -> usize { - GenericArray::::default().len() + self.0.nonce_len() } fn make_opening_key( @@ -46,14 +56,11 @@ impl super::Cipher for GcmCipher { _: &[u8], _: &dyn MacAlgorithm, ) -> Box { - let mut key = GenericArray::::default(); - key.clone_from_slice(k); - let mut nonce = GenericArray::::default(); - nonce.clone_from_slice(n); - Box::new(OpeningKey { - nonce, - cipher: Aes256Gcm::new(&key), - }) + #[allow(clippy::unwrap_used)] + Box::new(OpeningKey(AeadOpeningKey::new( + UnboundKey::new(self.0, k).unwrap(), + Nonce(n.try_into().unwrap()), + ))) } fn make_sealing_key( @@ -63,100 +70,76 @@ impl super::Cipher for GcmCipher { _: &[u8], _: &dyn MacAlgorithm, ) -> Box { - let mut key = GenericArray::::default(); - key.clone_from_slice(k); - let mut nonce = GenericArray::::default(); - nonce.clone_from_slice(n); - Box::new(SealingKey { - nonce, - cipher: Aes256Gcm::new(&key), - }) + #[allow(clippy::unwrap_used)] + Box::new(SealingKey(AeadSealingKey::new( + UnboundKey::new(self.0, k).unwrap(), + Nonce(n.try_into().unwrap()), + ))) } } -pub struct OpeningKey { - nonce: GenericArray, - cipher: Aes256Gcm, -} +pub struct OpeningKey(AeadOpeningKey); -pub struct SealingKey { - nonce: GenericArray, - cipher: Aes256Gcm, -} +pub struct SealingKey(AeadSealingKey); -const GCM_COUNTER_OFFSET: u64 = 3; - -fn make_nonce( - nonce: &GenericArray, - sequence_number: u32, -) -> GenericArray { - let mut new_nonce = GenericArray::::default(); - new_nonce.clone_from_slice(nonce); - // Increment the nonce - let i0 = new_nonce.len() - 8; - - #[allow(clippy::indexing_slicing)] // length checked - let ctr = BigEndian::read_u64(&new_nonce[i0..]); - - // GCM requires the counter to start from 1 - #[allow(clippy::indexing_slicing)] // length checked - BigEndian::write_u64( - &mut new_nonce[i0..], - ctr + sequence_number as u64 - GCM_COUNTER_OFFSET, - ); - new_nonce +struct Nonce([u8; NONCE_LEN]); + +impl NonceSequence for Nonce { + fn advance(&mut self) -> Result { + let mut previous_nonce = [0u8; NONCE_LEN]; + #[allow(clippy::indexing_slicing)] // length checked + previous_nonce.clone_from_slice(&self.0[..]); + let mut carry = 1; + #[allow(clippy::indexing_slicing)] // length checked + for i in (0..NONCE_LEN).rev() { + let n = self.0[i] as u16 + carry; + self.0[i] = n as u8; + carry = n >> 8; + } + Ok(AeadNonce::assume_unique_for_key(previous_nonce)) + } } -impl super::OpeningKey for OpeningKey { +impl super::OpeningKey for OpeningKey { fn decrypt_packet_length( &self, _sequence_number: u32, - encrypted_packet_length: [u8; 4], + encrypted_packet_length: &[u8], ) -> [u8; 4] { - encrypted_packet_length + // Fine because of self.packet_length_to_read_for_block_length() + #[allow(clippy::unwrap_used, clippy::indexing_slicing)] + encrypted_packet_length.try_into().unwrap() } fn tag_len(&self) -> usize { - TagSize::to_usize() + self.0.algorithm().tag_len() } fn open<'a>( &mut self, - sequence_number: u32, - ciphertext_in_plaintext_out: &'a mut [u8], - tag: &[u8], + _sequence_number: u32, + ciphertext_and_tag: &'a mut [u8], ) -> Result<&'a [u8], Error> { // Packet length is sent unencrypted let mut packet_length = [0; super::PACKET_LENGTH_LEN]; #[allow(clippy::indexing_slicing)] // length checked - packet_length.clone_from_slice(&ciphertext_in_plaintext_out[..super::PACKET_LENGTH_LEN]); - - let mut buffer = vec![0; ciphertext_in_plaintext_out.len() - super::PACKET_LENGTH_LEN]; - - #[allow(clippy::indexing_slicing)] // length checked - buffer.copy_from_slice(&ciphertext_in_plaintext_out[super::PACKET_LENGTH_LEN..]); - - let nonce = make_nonce(&self.nonce, sequence_number); - - let mut tag_buf = GenericArray::::default(); - tag_buf.clone_from_slice(tag); - - #[allow(clippy::indexing_slicing)] - self.cipher - .decrypt_in_place_detached( - &nonce, - &packet_length, - &mut ciphertext_in_plaintext_out[super::PACKET_LENGTH_LEN..], - &tag_buf, + packet_length.clone_from_slice(&ciphertext_and_tag[..super::PACKET_LENGTH_LEN]); + + let buf = self + .0 + .open_in_place( + Aad::from(&packet_length), + #[allow(clippy::indexing_slicing)] // length checked + &mut ciphertext_and_tag[super::PACKET_LENGTH_LEN..], ) .map_err(|_| Error::DecryptionError)?; - Ok(ciphertext_in_plaintext_out) + Ok(buf) } } -impl super::SealingKey for SealingKey { +impl super::SealingKey for SealingKey { fn padding_length(&self, payload: &[u8]) -> usize { let block_size = 16; let extra_len = super::PACKET_LENGTH_LEN + super::PADDING_LENGTH_LEN; @@ -177,12 +160,12 @@ impl super::SealingKey for SealingKey { } fn tag_len(&self) -> usize { - TagSize::to_usize() + self.0.algorithm().tag_len() } fn seal( &mut self, - sequence_number: u32, + _sequence_number: u32, plaintext_in_ciphertext_out: &mut [u8], tag: &mut [u8], ) { @@ -191,18 +174,16 @@ impl super::SealingKey for SealingKey { #[allow(clippy::indexing_slicing)] // length checked packet_length.clone_from_slice(&plaintext_in_ciphertext_out[..super::PACKET_LENGTH_LEN]); - let nonce = make_nonce(&self.nonce, sequence_number); - - #[allow(clippy::indexing_slicing, clippy::unwrap_used)] + #[allow(clippy::unwrap_used)] let tag_out = self - .cipher - .encrypt_in_place_detached( - &nonce, - &packet_length, + .0 + .seal_in_place_separate_tag( + Aad::from(&packet_length), + #[allow(clippy::indexing_slicing)] &mut plaintext_in_ciphertext_out[super::PACKET_LENGTH_LEN..], ) .unwrap(); - tag.clone_from_slice(&tag_out) + tag.clone_from_slice(tag_out.as_ref()); } } diff --git a/russh/src/cipher/mod.rs b/russh/src/cipher/mod.rs index 1251253d..d24542aa 100644 --- a/russh/src/cipher/mod.rs +++ b/russh/src/cipher/mod.rs @@ -14,26 +14,37 @@ //! //! This module exports cipher names for use with [Preferred]. +use std::borrow::Borrow; use std::collections::HashMap; +use std::convert::TryFrom; use std::fmt::Debug; use std::marker::PhantomData; use std::num::Wrapping; use aes::{Aes128, Aes192, Aes256}; +#[cfg(feature = "aws-lc-rs")] +use aws_lc_rs::aead::{AES_128_GCM as ALGORITHM_AES_128_GCM, AES_256_GCM as ALGORITHM_AES_256_GCM}; use byteorder::{BigEndian, ByteOrder}; use ctr::Ctr128BE; +use delegate::delegate; +use log::trace; use once_cell::sync::Lazy; +#[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] +use ring::aead::{AES_128_GCM as ALGORITHM_AES_128_GCM, AES_256_GCM as ALGORITHM_AES_256_GCM}; +use ssh_encoding::Encode; use tokio::io::{AsyncRead, AsyncReadExt}; -use log::debug; +use self::cbc::CbcWrapper; use crate::mac::MacAlgorithm; use crate::sshbuffer::SSHBuffer; use crate::Error; pub(crate) mod block; +pub(crate) mod cbc; pub(crate) mod chacha20poly1305; pub(crate) mod clear; pub(crate) mod gcm; + use block::SshBlockCipher; use chacha20poly1305::SshChacha20Poly1305Cipher; use clear::Clear; @@ -65,12 +76,23 @@ pub(crate) trait Cipher { /// `clear` pub const CLEAR: Name = Name("clear"); +/// `3des-cbc` +#[cfg(feature = "des")] +pub const TRIPLE_DES_CBC: Name = Name("3des-cbc"); /// `aes128-ctr` pub const AES_128_CTR: Name = Name("aes128-ctr"); /// `aes192-ctr` pub const AES_192_CTR: Name = Name("aes192-ctr"); +/// `aes128-cbc` +pub const AES_128_CBC: Name = Name("aes128-cbc"); +/// `aes192-cbc` +pub const AES_192_CBC: Name = Name("aes192-cbc"); +/// `aes256-cbc` +pub const AES_256_CBC: Name = Name("aes256-cbc"); /// `aes256-ctr` pub const AES_256_CTR: Name = Name("aes256-ctr"); +/// `aes128-gcm@openssh.com` +pub const AES_128_GCM: Name = Name("aes128-gcm@openssh.com"); /// `aes256-gcm@openssh.com` pub const AES_256_GCM: Name = Name("aes256-gcm@openssh.com"); /// `chacha20-poly1305@openssh.com` @@ -78,23 +100,52 @@ pub const CHACHA20_POLY1305: Name = Name("chacha20-poly1305@openssh.com"); /// `none` pub const NONE: Name = Name("none"); -static _CLEAR: Clear = Clear {}; +pub(crate) static _CLEAR: Clear = Clear {}; +#[cfg(feature = "des")] +static _3DES_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_128_CTR: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_192_CTR: SshBlockCipher> = SshBlockCipher(PhantomData); static _AES_256_CTR: SshBlockCipher> = SshBlockCipher(PhantomData); -static _AES_256_GCM: GcmCipher = GcmCipher {}; +static _AES_128_GCM: GcmCipher = GcmCipher(&ALGORITHM_AES_128_GCM); +static _AES_256_GCM: GcmCipher = GcmCipher(&ALGORITHM_AES_256_GCM); +static _AES_128_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); +static _AES_192_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); +static _AES_256_CBC: SshBlockCipher> = SshBlockCipher(PhantomData); static _CHACHA20_POLY1305: SshChacha20Poly1305Cipher = SshChacha20Poly1305Cipher {}; +pub static ALL_CIPHERS: &[&Name] = &[ + &CLEAR, + &NONE, + #[cfg(feature = "des")] + &TRIPLE_DES_CBC, + &AES_128_CTR, + &AES_192_CTR, + &AES_256_CTR, + &AES_128_GCM, + &AES_256_GCM, + &AES_128_CBC, + &AES_192_CBC, + &AES_256_CBC, + &CHACHA20_POLY1305, +]; + pub(crate) static CIPHERS: Lazy> = Lazy::new(|| { let mut h: HashMap<&'static Name, &(dyn Cipher + Send + Sync)> = HashMap::new(); h.insert(&CLEAR, &_CLEAR); h.insert(&NONE, &_CLEAR); + #[cfg(feature = "des")] + h.insert(&TRIPLE_DES_CBC, &_3DES_CBC); h.insert(&AES_128_CTR, &_AES_128_CTR); h.insert(&AES_192_CTR, &_AES_192_CTR); h.insert(&AES_256_CTR, &_AES_256_CTR); + h.insert(&AES_128_GCM, &_AES_128_GCM); h.insert(&AES_256_GCM, &_AES_256_GCM); + h.insert(&AES_128_CBC, &_AES_128_CBC); + h.insert(&AES_192_CBC, &_AES_192_CBC); + h.insert(&AES_256_CBC, &_AES_256_CBC); h.insert(&CHACHA20_POLY1305, &_CHACHA20_POLY1305); + assert_eq!(h.len(), ALL_CIPHERS.len()); h }); @@ -106,6 +157,26 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + +impl Borrow for &Name { + fn borrow(&self) -> &str { + self.0 + } +} + +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + CIPHERS.keys().find(|x| x.0 == s).map(|x| **x).ok_or(()) + } +} + pub(crate) struct CipherPair { pub local_to_remote: Box, pub remote_to_local: Box, @@ -118,16 +189,15 @@ impl Debug for CipherPair { } pub(crate) trait OpeningKey { - fn decrypt_packet_length(&self, seqn: u32, encrypted_packet_length: [u8; 4]) -> [u8; 4]; + fn packet_length_to_read_for_block_length(&self) -> usize { + 4 + } + + fn decrypt_packet_length(&self, seqn: u32, encrypted_packet_length: &[u8]) -> [u8; 4]; fn tag_len(&self) -> usize; - fn open<'a>( - &mut self, - seqn: u32, - ciphertext_in_plaintext_out: &'a mut [u8], - tag: &[u8], - ) -> Result<&'a [u8], Error>; + fn open<'a>(&mut self, seqn: u32, ciphertext_and_tag: &'a mut [u8]) -> Result<&'a [u8], Error>; } pub(crate) trait SealingKey { @@ -144,20 +214,21 @@ pub(crate) trait SealingKey { // // The variables `payload`, `packet_length` and `padding_length` refer // to the protocol fields of the same names. - debug!("writing, seqn = {:?}", buffer.seqn.0); + trace!("writing, seqn = {:?}", buffer.seqn.0); let padding_length = self.padding_length(payload); - debug!("padding length {:?}", padding_length); + trace!("padding length {:?}", padding_length); let packet_length = PADDING_LENGTH_LEN + payload.len() + padding_length; - debug!("packet_length {:?}", packet_length); + trace!("packet_length {:?}", packet_length); let offset = buffer.buffer.len(); // Maximum packet length: // https://tools.ietf.org/html/rfc4253#section-6.1 - assert!(packet_length <= std::u32::MAX as usize); - buffer.buffer.push_u32_be(packet_length as u32); + assert!(packet_length <= u32::MAX as usize); + #[allow(clippy::unwrap_used)] // length checked + (packet_length as u32).encode(&mut buffer.buffer).unwrap(); - assert!(padding_length <= std::u8::MAX as usize); + assert!(padding_length <= u8::MAX as usize); buffer.buffer.push(padding_length as u8); buffer.buffer.extend(payload); self.fill_padding(buffer.buffer.resize_mut(padding_length)); @@ -176,38 +247,47 @@ pub(crate) trait SealingKey { } } -pub(crate) async fn read<'a, R: AsyncRead + Unpin>( - stream: &'a mut R, - buffer: &'a mut SSHBuffer, - cipher: &'a mut (dyn OpeningKey + Send), +pub(crate) async fn read( + stream: &mut R, + buffer: &mut SSHBuffer, + cipher: &mut (dyn OpeningKey + Send), ) -> Result { if buffer.len == 0 { - let mut len = [0; 4]; + let mut len = vec![0; cipher.packet_length_to_read_for_block_length()]; + stream.read_exact(&mut len).await?; - debug!("reading, len = {:?}", len); + trace!("reading, len = {:?}", len); { let seqn = buffer.seqn.0; buffer.buffer.clear(); buffer.buffer.extend(&len); - debug!("reading, seqn = {:?}", seqn); - let len = cipher.decrypt_packet_length(seqn, len); - buffer.len = BigEndian::read_u32(&len) as usize + cipher.tag_len(); - debug!("reading, clear len = {:?}", buffer.len); + trace!("reading, seqn = {:?}", seqn); + let len = cipher.decrypt_packet_length(seqn, &len); + let len = BigEndian::read_u32(&len) as usize; + + if len > MAXIMUM_PACKET_LEN { + return Err(Error::PacketSize(len)); + } + + buffer.len = len + cipher.tag_len(); + trace!("reading, clear len = {:?}", buffer.len); } } buffer.buffer.resize(buffer.len + 4); - debug!("read_exact {:?}", buffer.len + 4); + trace!("read_exact {:?}", buffer.len + 4); + + let l = cipher.packet_length_to_read_for_block_length(); + #[allow(clippy::indexing_slicing)] // length checked - stream.read_exact(&mut buffer.buffer[4..]).await?; - debug!("read_exact done"); + stream.read_exact(&mut buffer.buffer[l..]).await?; + + trace!("read_exact done"); let seqn = buffer.seqn.0; - let ciphertext_len = buffer.buffer.len() - cipher.tag_len(); - let (ciphertext, tag) = buffer.buffer.split_at_mut(ciphertext_len); - let plaintext = cipher.open(seqn, ciphertext, tag)?; + let plaintext = cipher.open(seqn, &mut buffer.buffer)?; let padding_length = *plaintext.first().to_owned().unwrap_or(&0) as usize; - debug!("reading, padding_length {:?}", padding_length); + trace!("reading, padding_length {:?}", padding_length); let plaintext_end = plaintext .len() .checked_sub(padding_length) @@ -227,5 +307,9 @@ pub(crate) async fn read<'a, R: AsyncRead + Unpin>( pub(crate) const PACKET_LENGTH_LEN: usize = 4; const MINIMUM_PACKET_LEN: usize = 16; +const MAXIMUM_PACKET_LEN: usize = 256 * 1024; const PADDING_LENGTH_LEN: usize = 1; + +#[cfg(feature = "_bench")] +pub mod benchmark; diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index cc7df2b4..c6ac94d8 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -14,20 +14,25 @@ // use std::cell::RefCell; use std::convert::TryInto; +use std::ops::Deref; +use std::str::FromStr; +use bytes::Bytes; use log::{debug, error, info, trace, warn}; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::{Encoding, Reader}; -use russh_keys::key::parse_public_key; -use tokio::sync::mpsc::unbounded_channel; +use ssh_encoding::{Decode, Encode, Reader}; +use ssh_key::Algorithm; +use super::IncomingSshPacket; +use crate::auth::AuthRequest; +use crate::cert::PublicKeyOrCertificate; use crate::client::{Handler, Msg, Prompt, Reply, Session}; -use crate::key::PubKey; -use crate::negotiation::{Named, Select}; +use crate::helpers::{map_err, sign_with_hash_alg, AlgorithmExt, EncodedExt, NameList}; +use crate::keys::key::parse_public_key; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; -use crate::session::{Encrypted, EncryptedState, Kex, KexInit}; +use crate::session::{Encrypted, EncryptedState, GlobalRequestResponse}; use crate::{ - auth, msg, negotiation, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, Sig, + auth, msg, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, CryptoVec, Error, + MethodSet, Sig, }; thread_local! { @@ -36,121 +41,28 @@ thread_local! { impl Session { pub(crate) async fn client_read_encrypted( - mut self, - mut client: H, - buf: &[u8], - ) -> Result<(H, Self), H::Error> { + &mut self, + client: &mut H, + pkt: &mut IncomingSshPacket, + ) -> Result<(), H::Error> { #[allow(clippy::indexing_slicing)] // length checked { trace!( "client_read_encrypted, buf = {:?}", - &buf[..buf.len().min(20)] + &pkt.buffer[..pkt.buffer.len().min(20)] ); } - // Either this packet is a KEXINIT, in which case we start a key re-exchange. - if buf.first() == Some(&msg::KEXINIT) { - debug!("Received KEXINIT"); - // Now, if we're encrypted: - if let Some(ref mut enc) = self.common.encrypted { - // If we're not currently re-keying, but buf is a rekey request - let kexinit = if let Some(Kex::Init(kexinit)) = enc.rekey.take() { - Some(kexinit) - } else if let Some(exchange) = std::mem::replace(&mut enc.exchange, None) { - Some(KexInit::received_rekey( - exchange, - negotiation::Client::read_kex(buf, &self.common.config.as_ref().preferred)?, - &enc.session_id, - )) - } else { - None - }; - - if let Some(kexinit) = kexinit { - let dhdone = kexinit.client_parse( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - buf, - &mut self.common.write_buffer, - )?; - - if !enc.kex.skip_exchange() { - enc.rekey = Some(Kex::DhDone(dhdone)); - } - } - } else { - unreachable!() - } - self.flush()?; - return Ok((client, self)); - } - if let Some(ref mut enc) = self.common.encrypted { - match enc.rekey.take() { - Some(Kex::DhDone(mut kexdhdone)) => { - return if kexdhdone.names.ignore_guessed { - kexdhdone.names.ignore_guessed = false; - enc.rekey = Some(Kex::DhDone(kexdhdone)); - Ok((client, self)) - } else if buf.first() == Some(&msg::KEX_ECDH_REPLY) { - // We've sent ECDH_INIT, waiting for ECDH_REPLY - let (kex, h) = kexdhdone.server_key_check(true, client, buf).await?; - client = h; - enc.rekey = Some(Kex::Keys(kex)); - self.common - .cipher - .local_to_remote - .write(&[msg::NEWKEYS], &mut self.common.write_buffer); - self.flush()?; - Ok((client, self)) - } else { - error!("Wrong packet received"); - Err(crate::Error::Inconsistent.into()) - }; - } - Some(Kex::Keys(newkeys)) => { - if buf.first() != Some(&msg::NEWKEYS) { - return Err(crate::Error::Kex.into()); - } - self.common.write_buffer.bytes = 0; - enc.last_rekey = std::time::Instant::now(); - - // Ok, NEWKEYS received, now encrypted. - enc.flush_all_pending(); - let mut pending = std::mem::take(&mut self.pending_reads); - for p in pending.drain(..) { - let (h, s) = self.process_packet(client, &p).await?; - self = s; - client = h; - } - self.pending_reads = pending; - self.pending_len = 0; - self.common.newkeys(newkeys); - self.flush()?; - return Ok((client, self)); - } - Some(Kex::Init(k)) => { - enc.rekey = Some(Kex::Init(k)); - self.pending_len += buf.len() as u32; - if self.pending_len > 2 * self.target_window_size { - return Err(crate::Error::Pending.into()); - } - self.pending_reads.push(CryptoVec::from_slice(buf)); - return Ok((client, self)); - } - rek => enc.rekey = rek, - } - } - self.process_packet(client, buf).await + self.process_packet(client, &pkt.buffer).await } - async fn process_packet( - mut self, - client: H, + pub(crate) async fn process_packet( + &mut self, + client: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { // If we've successfully read a packet. trace!("process_packet buf = {:?} bytes", buf.len()); - trace!("buf = {:?}", buf); let mut is_authenticated = false; if let Some(ref mut enc) = self.common.encrypted { match enc.state { @@ -162,200 +74,194 @@ impl Session { buf.first(), msg::SERVICE_ACCEPT ); - if buf.first() == Some(&msg::SERVICE_ACCEPT) { - let mut r = buf.reader(1); - if r.read_string().map_err(crate::Error::from)? == b"ssh-userauth" { - *accepted = true; - if let Some(ref meth) = self.common.auth_method { - let auth_request = match meth { - crate::auth::Method::KeyboardInteractive { submethods } => { - auth::AuthRequest { - methods: auth::MethodSet::all(), - partial_success: false, - current: Some( - auth::CurrentRequest::KeyboardInteractive { - submethods: submethods.to_string(), - }, - ), - rejection_count: 0, - } + match buf.split_first() { + Some((&msg::SERVICE_ACCEPT, mut r)) => { + if map_err!(Bytes::decode(&mut r))?.as_ref() == b"ssh-userauth" { + *accepted = true; + if let Some(ref meth) = self.common.auth_method { + let len = enc.write.len(); + let auth_request = AuthRequest::new(meth); + #[allow(clippy::indexing_slicing)] // length checked + if enc.write_auth_request(&self.common.auth_user, meth)? { + debug!("enc: {:?}", &enc.write[len..]); + enc.state = EncryptedState::WaitingAuthRequest(auth_request) } - _ => auth::AuthRequest { - methods: auth::MethodSet::all(), - partial_success: false, - current: None, - rejection_count: 0, - }, - }; - let len = enc.write.len(); - #[allow(clippy::indexing_slicing)] // length checked - if enc.write_auth_request(&self.common.auth_user, meth) { - debug!("enc: {:?}", &enc.write[len..]); - enc.state = EncryptedState::WaitingAuthRequest(auth_request) + } else { + debug!("no auth method") } - } else { - debug!("no auth method") } } - } else if buf.first() == Some(&msg::EXT_INFO) { - return self.handle_ext_info(client, buf); - } else { - debug!("unknown message: {:?}", buf); - return Err(crate::Error::Inconsistent.into()); + Some((&msg::EXT_INFO, mut r)) => { + return self.handle_ext_info(&mut r).map_err(Into::into); + } + other => { + debug!("unknown message: {other:?}"); + return Err(crate::Error::Inconsistent.into()); + } } } EncryptedState::WaitingAuthRequest(ref mut auth_request) => { - if buf.first() == Some(&msg::USERAUTH_SUCCESS) { - debug!("userauth_success"); - self.sender - .send(Reply::AuthSuccess) - .map_err(|_| crate::Error::SendError)?; - enc.state = EncryptedState::InitCompression; - enc.server_compression.init_decompress(&mut enc.decompress); - return Ok((client, self)); - } else if buf.first() == Some(&msg::USERAUTH_BANNER) { - let mut r = buf.reader(1); - let banner = r.read_string().map_err(crate::Error::from)?; - return if let Ok(banner) = std::str::from_utf8(banner) { - let (h, s) = client.auth_banner(banner, self).await?; - Ok((h, s)) - } else { - Ok((client, self)) - }; - } else if buf.first() == Some(&msg::USERAUTH_FAILURE) { - debug!("userauth_failure"); - - let mut r = buf.reader(1); - let remaining_methods = r.read_string().map_err(crate::Error::from)?; - debug!( - "remaining methods {:?}", - std::str::from_utf8(remaining_methods) - ); - auth_request.methods = auth::MethodSet::empty(); - for method in remaining_methods.split(|&c| c == b',') { - if let Some(m) = auth::MethodSet::from_bytes(method) { - auth_request.methods |= m - } + trace!("waiting auth request, {:?}", buf.first(),); + match buf.split_first() { + Some((&msg::USERAUTH_SUCCESS, _)) => { + debug!("userauth_success"); + self.sender + .send(Reply::AuthSuccess) + .map_err(|_| crate::Error::SendError)?; + enc.state = EncryptedState::InitCompression; + enc.server_compression.init_decompress(&mut enc.decompress); + return Ok(()); } - let no_more_methods = auth_request.methods.is_empty(); - self.common.auth_method = None; - self.sender - .send(Reply::AuthFailure) - .map_err(|_| crate::Error::SendError)?; - - // If no other authentication method is allowed by the server, give up. - if no_more_methods { - return Err(crate::Error::NoAuthMethod.into()); + Some((&msg::USERAUTH_BANNER, mut r)) => { + let banner = map_err!(String::decode(&mut r))?; + client.auth_banner(&banner, self).await?; + return Ok(()); } - } else if buf.first() == Some(&msg::USERAUTH_INFO_REQUEST_OR_USERAUTH_PK_OK) { - if let Some(auth::CurrentRequest::PublicKey { - ref mut sent_pk_ok, .. - }) = auth_request.current - { - debug!("userauth_pk_ok"); - *sent_pk_ok = true; - } else if let Some(auth::CurrentRequest::KeyboardInteractive { .. }) = - auth_request.current - { - debug!("keyboard_interactive"); - let mut r = buf.reader(1); - - // read fields - let name = String::from_utf8_lossy( - r.read_string().map_err(crate::Error::from)?, - ) - .to_string(); + Some((&msg::USERAUTH_FAILURE, mut r)) => { + debug!("userauth_failure"); - let instructions = String::from_utf8_lossy( - r.read_string().map_err(crate::Error::from)?, - ) - .to_string(); - - let _lang = r.read_string().map_err(crate::Error::from)?; - let n_prompts = r.read_u32().map_err(crate::Error::from)?; - - // read prompts - let mut prompts = Vec::with_capacity(n_prompts.try_into().unwrap_or(0)); - for _i in 0..n_prompts { - let prompt = String::from_utf8_lossy( - r.read_string().map_err(crate::Error::from)?, - ); - - let echo = r.read_byte().map_err(crate::Error::from)? != 0; - prompts.push(Prompt { - prompt: prompt.to_string(), - echo, - }); - } + let remaining_methods: MethodSet = + (&map_err!(NameList::decode(&mut r))?).into(); + let partial_success = map_err!(u8::decode(&mut r))? != 0; + debug!("remaining methods {remaining_methods:?}, partial success {partial_success:?}"); + auth_request.methods = remaining_methods.clone(); - // send challenges to caller + let no_more_methods = auth_request.methods.is_empty(); + self.common.auth_method = None; self.sender - .send(Reply::AuthInfoRequest { - name, - instructions, - prompts, + .send(Reply::AuthFailure { + proceed_with_methods: remaining_methods, + partial_success, }) .map_err(|_| crate::Error::SendError)?; - // wait for response from handler - let responses = loop { - match self.receiver.recv().await { - Some(Msg::AuthInfoResponse { responses }) => break responses, - _ => {} - } - }; - // write responses - enc.client_send_auth_response(&responses)?; - return Ok((client, self)); - } else { + // If no other authentication method is allowed by the server, give up. + if no_more_methods { + return Err(crate::Error::NoAuthMethod.into()); + } } + Some((&msg::USERAUTH_INFO_REQUEST_OR_USERAUTH_PK_OK, mut r)) => { + if let Some(auth::CurrentRequest::PublicKey { + ref mut sent_pk_ok, + .. + }) = auth_request.current + { + debug!("userauth_pk_ok"); + *sent_pk_ok = true; + } else if let Some(auth::CurrentRequest::KeyboardInteractive { + .. + }) = auth_request.current + { + debug!("keyboard_interactive"); - // continue with userauth_pk_ok - match self.common.auth_method.take() { - Some(auth_method @ auth::Method::PublicKey { .. }) => { - self.common.buffer.clear(); - enc.client_send_signature( - &self.common.auth_user, - &auth_method, - &mut self.common.buffer, - )? - } - Some(auth::Method::FuturePublicKey { key }) => { - debug!("public key"); - self.common.buffer.clear(); - let i = enc.client_make_to_sign( - &self.common.auth_user, - &key, - &mut self.common.buffer, - ); - let len = self.common.buffer.len(); - let buf = - std::mem::replace(&mut self.common.buffer, CryptoVec::new()); + // read fields + let name = map_err!(String::decode(&mut r))?; + + let instructions = map_err!(String::decode(&mut r))?; + + let _lang = map_err!(String::decode(&mut r))?; + let n_prompts = map_err!(u32::decode(&mut r))?; + + // read prompts + let mut prompts = + Vec::with_capacity(n_prompts.try_into().unwrap_or(0)); + for _i in 0..n_prompts { + let prompt = map_err!(String::decode(&mut r))?; + + let echo = map_err!(u8::decode(&mut r))? != 0; + prompts.push(Prompt { + prompt: prompt.to_string(), + echo, + }); + } + // send challenges to caller self.sender - .send(Reply::SignRequest { key, data: buf }) + .send(Reply::AuthInfoRequest { + name, + instructions, + prompts, + }) .map_err(|_| crate::Error::SendError)?; - self.common.buffer = loop { + + // wait for response from handler + let responses = loop { match self.receiver.recv().await { - Some(Msg::Signed { data }) => break data, + Some(Msg::AuthInfoResponse { responses }) => { + break responses + } + None => return Err(crate::Error::RecvError.into()), _ => {} } }; - if self.common.buffer.len() != len { - // The buffer was modified. - push_packet!(enc.write, { - #[allow(clippy::indexing_slicing)] // length checked - enc.write.extend(&self.common.buffer[i..]); - }) + // write responses + enc.client_send_auth_response(&responses)?; + return Ok(()); + } + + // continue with userauth_pk_ok + match self.common.auth_method.take() { + Some(auth_method @ auth::Method::PublicKey { .. }) => { + self.common.buffer.clear(); + enc.client_send_signature( + &self.common.auth_user, + &auth_method, + &mut self.common.buffer, + )? + } + Some(auth_method @ auth::Method::OpenSshCertificate { .. }) => { + self.common.buffer.clear(); + enc.client_send_signature( + &self.common.auth_user, + &auth_method, + &mut self.common.buffer, + )? + } + Some(auth::Method::FuturePublicKey { key, hash_alg }) => { + debug!("public key"); + self.common.buffer.clear(); + let i = enc.client_make_to_sign( + &self.common.auth_user, + &PublicKeyOrCertificate::PublicKey { + key: key.clone(), + hash_alg, + }, + &mut self.common.buffer, + )?; + let len = self.common.buffer.len(); + let buf = std::mem::replace( + &mut self.common.buffer, + CryptoVec::new(), + ); + + self.sender + .send(Reply::SignRequest { key, data: buf }) + .map_err(|_| crate::Error::SendError)?; + self.common.buffer = loop { + match self.receiver.recv().await { + Some(Msg::Signed { data }) => break data, + None => return Err(crate::Error::RecvError.into()), + _ => {} + } + }; + if self.common.buffer.len() != len { + // The buffer was modified. + push_packet!(enc.write, { + #[allow(clippy::indexing_slicing)] // length checked + enc.write.extend(&self.common.buffer[i..]); + }) + } } + _ => {} } - _ => {} } - } else if buf.first() == Some(&msg::EXT_INFO) { - return self.handle_ext_info(client, buf); - } else { - debug!("unknown message: {:?}", buf); - return Err(crate::Error::Inconsistent.into()); + Some((&msg::EXT_INFO, mut r)) => { + return self.handle_ext_info(&mut r).map_err(Into::into); + } + other => { + debug!("unknown message: {other:?}"); + return Err(crate::Error::Inconsistent.into()); + } } } EncryptedState::InitCompression => unreachable!(), @@ -365,25 +271,57 @@ impl Session { if is_authenticated { self.client_read_authenticated(client, buf).await } else { - Ok((client, self)) + Ok(()) } } - fn handle_ext_info(self, client: H, buf: &[u8]) -> Result<(H, Self), H::Error> { - debug!("Received EXT_INFO: {:?}", buf); - Ok((client, self)) + fn handle_ext_info(&mut self, r: &mut impl Reader) -> Result<(), Error> { + let n_extensions = u32::decode(r)? as usize; + debug!("Received EXT_INFO, {n_extensions:?} extensions"); + for _ in 0..n_extensions { + let name = String::decode(r)?; + if name == "server-sig-algs" { + self.handle_server_sig_algs_ext(r)?; + } else { + let data = Vec::::decode(r)?; + debug!("* {name:?} (unknown, data: {data:?})"); + } + if let Some(ref mut enc) = self.common.encrypted { + enc.received_extensions.push(name.clone()); + if let Some(mut senders) = enc.extension_info_awaiters.remove(&name) { + senders.drain(..).for_each(|w| { + let _ = w.send(()); + }); + } + } + } + Ok(()) + } + + fn handle_server_sig_algs_ext(&mut self, r: &mut impl Reader) -> Result<(), Error> { + let algs = NameList::decode(r)?; + debug!("* server-sig-algs"); + self.server_sig_algs = Some( + algs.0 + .iter() + .filter_map(|x| Algorithm::from_str(x).ok()) + .inspect(|x| { + debug!(" * {x:?}"); + }) + .collect::>(), + ); + Ok(()) } async fn client_read_authenticated( - mut self, - mut client: H, + &mut self, + client: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { - match buf.first() { - Some(&msg::CHANNEL_OPEN_CONFIRMATION) => { + ) -> Result<(), H::Error> { + match buf.split_first() { + Some((&msg::CHANNEL_OPEN_CONFIRMATION, mut reader)) => { debug!("channel_open_confirmation"); - let mut reader = buf.reader(1); - let msg = ChannelOpenConfirmation::parse(&mut reader)?; + let msg = map_err!(ChannelOpenConfirmation::decode(&mut reader))?; let local_id = ChannelId(msg.recipient_channel); if let Some(ref mut enc) = self.common.encrypted { @@ -404,6 +342,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {local_id:?}"); @@ -418,60 +357,53 @@ impl Session { ) .await } - Some(&msg::CHANNEL_CLOSE) => { + Some((&msg::CHANNEL_CLOSE, mut r)) => { debug!("channel_close"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(ref mut enc) = self.common.encrypted { // The CHANNEL_CLOSE message must be sent to the server at this point or the session // will not be released. - enc.close(channel_num); + enc.close(channel_num)?; } self.channels.remove(&channel_num); client.channel_close(channel_num, self).await } - Some(&msg::CHANNEL_EOF) => { + Some((&msg::CHANNEL_EOF, mut r)) => { debug!("channel_eof"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Eof); + let _ = chan.send(ChannelMsg::Eof).await; } client.channel_eof(channel_num, self).await } - Some(&msg::CHANNEL_OPEN_FAILURE) => { + Some((&msg::CHANNEL_OPEN_FAILURE, mut r)) => { debug!("channel_open_failure"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let reason_code = - ChannelOpenFailure::from_u32(r.read_u32().map_err(crate::Error::from)?) - .unwrap_or(ChannelOpenFailure::Unknown); - let descr = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let language = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let reason_code = ChannelOpenFailure::from_u32(map_err!(u32::decode(&mut r))?) + .unwrap_or(ChannelOpenFailure::Unknown); + let descr = map_err!(String::decode(&mut r))?; + let language = map_err!(String::decode(&mut r))?; if let Some(ref mut enc) = self.common.encrypted { enc.channels.remove(&channel_num); } if let Some(sender) = self.channels.remove(&channel_num) { - let _ = sender.send(ChannelMsg::OpenFailure(reason_code)); + let _ = sender.send(ChannelMsg::OpenFailure(reason_code)).await; } let _ = self.sender.send(Reply::ChannelOpenFailure); client - .channel_open_failure(channel_num, reason_code, descr, language, self) + .channel_open_failure(channel_num, reason_code, &descr, &language, self) .await } - Some(&msg::CHANNEL_DATA) => { + Some((&msg::CHANNEL_DATA, mut r)) => { trace!("channel_data"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let data = r.read_string().map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let data = map_err!(Bytes::decode(&mut r))?; let target = self.common.config.window_size; if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, data, target) { + if enc.adjust_window_size(channel_num, &data, target)? { let next_window = client.adjust_window(channel_num, self.target_window_size); if next_window > 0 { @@ -481,22 +413,23 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(data), - }); + let _ = chan + .send(ChannelMsg::Data { + data: CryptoVec::from_slice(&data), + }) + .await; } - client.data(channel_num, data, self).await + client.data(channel_num, &data, self).await } - Some(&msg::CHANNEL_EXTENDED_DATA) => { + Some((&msg::CHANNEL_EXTENDED_DATA, mut r)) => { debug!("channel_extended_data"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let extended_code = r.read_u32().map_err(crate::Error::from)?; - let data = r.read_string().map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let extended_code = map_err!(u32::decode(&mut r))?; + let data = map_err!(Bytes::decode(&mut r))?; let target = self.common.config.window_size; if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, data, target) { + if enc.adjust_window_size(channel_num, &data, target)? { let next_window = client.adjust_window(channel_num, self.target_window_size); if next_window > 0 { @@ -506,118 +439,104 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExtendedData { - ext: extended_code, - data: CryptoVec::from_slice(data), - }); + let _ = chan + .send(ChannelMsg::ExtendedData { + ext: extended_code, + data: CryptoVec::from_slice(&data), + }) + .await; } client - .extended_data(channel_num, extended_code, data, self) + .extended_data(channel_num, extended_code, &data, self) .await } - Some(&msg::CHANNEL_REQUEST) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let req = r.read_string().map_err(crate::Error::from)?; - debug!( - "channel_request: {:?} {:?}", - channel_num, - std::str::from_utf8(req) - ); - match req { - b"xon-xoff" => { - r.read_byte().map_err(crate::Error::from)?; // should be 0. - let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0; + Some((&msg::CHANNEL_REQUEST, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let req = map_err!(String::decode(&mut r))?; + debug!("channel_request: {channel_num:?} {req:?}",); + match req.as_str() { + "xon-xoff" => { + map_err!(u8::decode(&mut r))?; // should be 0. + let client_can_do = map_err!(u8::decode(&mut r))? != 0; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::XonXoff { client_can_do }); + let _ = chan.send(ChannelMsg::XonXoff { client_can_do }).await; } client.xon_xoff(channel_num, client_can_do, self).await } - b"exit-status" => { - r.read_byte().map_err(crate::Error::from)?; // should be 0. - let exit_status = r.read_u32().map_err(crate::Error::from)?; + "exit-status" => { + map_err!(u8::decode(&mut r))?; // should be 0. + let exit_status = map_err!(u32::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitStatus { exit_status }); + let _ = chan.send(ChannelMsg::ExitStatus { exit_status }).await; } client.exit_status(channel_num, exit_status, self).await } - b"exit-signal" => { - r.read_byte().map_err(crate::Error::from)?; // should be 0. + "exit-signal" => { + map_err!(u8::decode(&mut r))?; // should be 0. let signal_name = - Sig::from_name(r.read_string().map_err(crate::Error::from)?)?; - let core_dumped = r.read_byte().map_err(crate::Error::from)? != 0; - let error_message = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let lang_tag = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + Sig::from_name(map_err!(String::decode(&mut r))?.as_str()); + let core_dumped = map_err!(u8::decode(&mut r))? != 0; + let error_message = map_err!(String::decode(&mut r))?; + let lang_tag = map_err!(String::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitSignal { - signal_name: signal_name.clone(), - core_dumped, - error_message: error_message.to_string(), - lang_tag: lang_tag.to_string(), - }); + let _ = chan + .send(ChannelMsg::ExitSignal { + signal_name: signal_name.clone(), + core_dumped, + error_message: error_message.to_string(), + lang_tag: lang_tag.to_string(), + }) + .await; } client .exit_signal( channel_num, signal_name, core_dumped, - error_message, - lang_tag, + &error_message, + &lang_tag, self, ) .await } - b"keepalive@openssh.com" => { - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + "keepalive@openssh.com" => { + let wants_reply = map_err!(u8::decode(&mut r))?; if wants_reply == 1 { if let Some(ref mut enc) = self.common.encrypted { - trace!( - "Received channel keep alive message: {:?}", - std::str::from_utf8(req), - ); + trace!("Received channel keep alive message: {req:?}",); self.common.wants_reply = false; push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_SUCCESS); - enc.write.push_u32_be(channel_num.0) + map_err!(msg::CHANNEL_SUCCESS.encode(&mut enc.write))?; + map_err!(channel_num.encode(&mut enc.write))?; }); } } else { warn!("Received keepalive without reply request!"); } - Ok((client, self)) + Ok(()) } _ => { - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + let wants_reply = map_err!(u8::decode(&mut r))?; if wants_reply == 1 { if let Some(ref mut enc) = self.common.encrypted { self.common.wants_reply = false; push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_FAILURE); - enc.write.push_u32_be(channel_num.0) + map_err!(msg::CHANNEL_FAILURE.encode(&mut enc.write))?; + map_err!(channel_num.encode(&mut enc.write))?; }) } } - info!( - "Unknown channel request {:?} {:?}", - std::str::from_utf8(req), - wants_reply - ); - Ok((client, self)) + info!("Unknown channel request {req:?} {wants_reply:?}",); + Ok(()) } } } - Some(&msg::CHANNEL_WINDOW_ADJUST) => { - debug!("channel_window_adjust"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let amount = r.read_u32().map_err(crate::Error::from)?; + Some((&msg::CHANNEL_WINDOW_ADJUST, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let amount = map_err!(u32::decode(&mut r))?; let mut new_size = 0; - debug!("amount: {:?}", amount); + debug!("channel_window_adjust amount: {:?}", amount); if let Some(ref mut enc) = self.common.encrypted { if let Some(ref mut channel) = enc.channels.get_mut(&channel_num) { channel.recipient_window_size += amount; @@ -628,87 +547,73 @@ impl Session { } if let Some(ref mut enc) = self.common.encrypted { - new_size -= enc.flush_pending(channel_num) as u32; + new_size -= enc.flush_pending(channel_num)? as u32; } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); + chan.window_size().update(new_size).await; + + let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }).await; } client.window_adjusted(channel_num, new_size, self).await } - Some(&msg::GLOBAL_REQUEST) => { - let mut r = buf.reader(1); - let req = r.read_string().map_err(crate::Error::from)?; - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + Some((&msg::GLOBAL_REQUEST, mut r)) => { + let req = map_err!(String::decode(&mut r))?; + let wants_reply = map_err!(u8::decode(&mut r))?; if let Some(ref mut enc) = self.common.encrypted { - if req.starts_with(b"keepalive") { + if req.starts_with("keepalive") { if wants_reply == 1 { - trace!( - "Received keep alive message: {:?}", - std::str::from_utf8(req), - ); + trace!("Received keep alive message: {req:?}",); self.common.wants_reply = false; push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)); } else { warn!("Received keepalive without reply request!"); } - } else if req == b"hostkeys-00@openssh.com" { + } else if req == "hostkeys-00@openssh.com" { let mut keys = vec![]; loop { - match r.read_string() { + match Bytes::decode(&mut r) { Ok(key) => { - let key2 = <&[u8]>::clone(&key); - #[cfg(not(feature = "openssl"))] - let key = parse_public_key(key).map_err(crate::Error::from); - #[cfg(feature = "openssl")] - let key = - parse_public_key(key, None).map_err(crate::Error::from); + let key = map_err!(parse_public_key(&key)); match key { Ok(key) => keys.push(key), - Err(err) => { + Err(ref err) => { debug!( - "failed to parse announced host key {:?}: {:?}", - key2, err + "failed to parse announced host key {key:?}: {err:?}", ) } } } - Err(russh_keys::Error::IndexOutOfBounds) => break, + Err(ssh_encoding::Error::Length) => break, x => { - x.map_err(crate::Error::from)?; + map_err!(x)?; } } } return client.openssh_ext_host_keys_announced(keys, self).await; } else { - warn!( - "Unhandled global request: {:?} {:?}", - std::str::from_utf8(req), - wants_reply - ); + warn!("Unhandled global request: {req:?} {wants_reply:?}",); self.common.wants_reply = false; push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - Ok((client, self)) + self.common.received_data = false; + Ok(()) } - Some(&msg::CHANNEL_SUCCESS) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + Some((&msg::CHANNEL_SUCCESS, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Success); + let _ = chan.send(ChannelMsg::Success).await; } client.channel_success(channel_num, self).await } - Some(&msg::CHANNEL_FAILURE) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + Some((&msg::CHANNEL_FAILURE, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Failure); + let _ = chan.send(ChannelMsg::Failure).await; } client.channel_failure(channel_num, self).await } - Some(&msg::CHANNEL_OPEN) => { - let mut r = buf.reader(1); + Some((&msg::CHANNEL_OPEN, mut r)) => { let msg = OpenChannelMessage::parse(&mut r)?; if let Some(ref mut enc) = self.common.encrypted { @@ -723,29 +628,34 @@ impl Session { confirmed: true, wants_reply: false, pending_data: std::collections::VecDeque::new(), + pending_eof: false, + pending_close: false, }; let confirm = || { debug!("confirming channel: {:?}", msg); - msg.confirm( + map_err!(msg.confirm( &mut enc.write, id.0, channel.sender_window_size, channel.sender_maximum_packet_size, - ); + ))?; enc.channels.insert(id, channel); + Ok(()) }; - Ok(match &msg.typ { + match &msg.typ { ChannelType::Session => { - confirm(); - client.server_channel_open_session(id, self).await? + confirm()?; + let channel = self.accept_server_initiated_channel(id, &msg); + client.server_channel_open_session(channel, self).await? } ChannelType::DirectTcpip(d) => { - confirm(); + confirm()?; + let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_direct_tcpip( - id, + channel, &d.host_to_connect, d.port_to_connect, &d.originator_address, @@ -754,11 +664,22 @@ impl Session { ) .await? } + ChannelType::DirectStreamLocal(d) => { + confirm()?; + let channel = self.accept_server_initiated_channel(id, &msg); + client + .server_channel_open_direct_streamlocal( + channel, + &d.socket_path, + self, + ) + .await? + } ChannelType::X11 { originator_address, originator_port, } => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_x11( @@ -770,7 +691,7 @@ impl Session { .await? } ChannelType::ForwardedTcpIp(d) => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_forwarded_tcpip( @@ -783,27 +704,109 @@ impl Session { ) .await? } + ChannelType::ForwardedStreamLocal(d) => { + confirm()?; + let channel = self.accept_server_initiated_channel(id, &msg); + client + .server_channel_open_forwarded_streamlocal( + channel, + &d.socket_path, + self, + ) + .await?; + } ChannelType::AgentForward => { - confirm(); - client.server_channel_open_agent_forward(id, self).await? + confirm()?; + let channel = self.accept_server_initiated_channel(id, &msg); + client + .server_channel_open_agent_forward(channel, self) + .await? } ChannelType::Unknown { typ } => { - if client.server_channel_handle_unknown(id, typ) { - confirm(); + if client.should_accept_unknown_server_channel(id, typ).await { + confirm()?; + let channel = self.accept_server_initiated_channel(id, &msg); + client.server_channel_open_unknown(channel, self).await?; } else { - debug!("unknown channel type: {}", String::from_utf8_lossy(typ)); - msg.unknown_type(&mut enc.write); + debug!("unknown channel type: {typ}"); + msg.unknown_type(&mut enc.write)?; } - (client, self) } - }) + }; + Ok(()) } else { Err(crate::Error::Inconsistent.into()) } } - _ => { - info!("Unhandled packet: {:?}", buf); - Ok((client, self)) + Some((&msg::REQUEST_SUCCESS, mut r)) => { + trace!("Global Request Success"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::NoMoreSessions) => { + debug!("no-more-sessions@openssh.com requests success"); + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let result = if r.is_empty() { + // If a specific port was requested, the reply has no data + Some(0) + } else { + match u32::decode(&mut r) { + Ok(port) => Some(port), + Err(e) => { + error!("Error parsing port for TcpIpForward request: {e:?}"); + None + } + } + }; + let _ = return_channel.send(result); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(true); + } + Some(GlobalRequestResponse::StreamLocalForward(return_channel)) => { + let _ = return_channel.send(true); + } + Some(GlobalRequestResponse::CancelStreamLocalForward(return_channel)) => { + let _ = return_channel.send(true); + } + None => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) + } + Some((&msg::REQUEST_FAILURE, _)) => { + trace!("global request failure"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::NoMoreSessions) => { + warn!("no-more-sessions@openssh.com requests failure"); + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let _ = return_channel.send(None); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(false); + } + Some(GlobalRequestResponse::StreamLocalForward(return_channel)) => { + let _ = return_channel.send(false); + } + Some(GlobalRequestResponse::CancelStreamLocalForward(return_channel)) => { + let _ = return_channel.send(false); + } + None => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) + } + m => { + debug!("unknown message received: {:?}", m); + Ok(()) } } } @@ -813,18 +816,24 @@ impl Session { id: ChannelId, msg: &OpenChannelMessage, ) -> Channel { - let (sender, receiver) = unbounded_channel(); - self.channels.insert(id, sender); - Channel { + let (channel, channel_ref) = Channel::new( id, - sender: self.inbound_channel_sender.clone(), - receiver, - max_packet_size: msg.recipient_maximum_packet_size, - window_size: msg.recipient_window_size, - } + self.inbound_channel_sender.clone(), + msg.recipient_maximum_packet_size, + msg.recipient_window_size, + self.common.config.channel_buffer_size, + ); + + self.channels.insert(id, channel_ref); + + channel } - pub(crate) fn write_auth_request_if_needed(&mut self, user: &str, meth: auth::Method) -> bool { + pub(crate) fn write_auth_request_if_needed( + &mut self, + user: &str, + meth: auth::Method, + ) -> Result { let mut is_waiting = false; if let Some(ref mut enc) = self.common.encrypted { is_waiting = match enc.state { @@ -835,11 +844,11 @@ impl Session { } => { debug!("sending ssh-userauth service requset"); if !*sent { - let p = b"\x05\0\0\0\x0Cssh-userauth"; - self.common - .cipher - .local_to_remote - .write(p, &mut self.common.write_buffer); + self.common.packet_writer.packet(|w| { + msg::SERVICE_REQUEST.encode(w)?; + "ssh-userauth".encode(w)?; + Ok(()) + })?; *sent = true } accepted @@ -851,89 +860,121 @@ impl Session { is_waiting ); if is_waiting { - enc.write_auth_request(user, &meth); + enc.write_auth_request(user, &meth)?; + let auth_request = AuthRequest::new(&meth); + enc.state = EncryptedState::WaitingAuthRequest(auth_request); } } self.common.auth_user.clear(); self.common.auth_user.push_str(user); self.common.auth_method = Some(meth); - is_waiting + Ok(is_waiting) } } impl Encrypted { - fn write_auth_request(&mut self, user: &str, auth_method: &auth::Method) -> bool { + fn write_auth_request( + &mut self, + user: &str, + auth_method: &auth::Method, + ) -> Result { // The server is waiting for our USERAUTH_REQUEST. - push_packet!(self.write, { + Ok(push_packet!(self.write, { self.write.push(msg::USERAUTH_REQUEST); match *auth_method { auth::Method::None => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"none"); + user.encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "none".encode(&mut self.write)?; true } auth::Method::Password { ref password } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"password"); - self.write.push(0); - self.write.extend_ssh_string(password.as_bytes()); + user.encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "password".encode(&mut self.write)?; + 0u8.encode(&mut self.write)?; + password.encode(&mut self.write)?; true } auth::Method::PublicKey { ref key } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"publickey"); + user.encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "publickey".encode(&mut self.write)?; self.write.push(0); // This is a probe - debug!("write_auth_request: {:?}", key.name()); - self.write.extend_ssh_string(key.name().as_bytes()); - key.push_to(&mut self.write); + debug!("write_auth_request: key - {:?}", key.algorithm()); + key.algorithm().as_str().encode(&mut self.write)?; + key.public_key().to_bytes()?.encode(&mut self.write)?; true } - auth::Method::FuturePublicKey { ref key, .. } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"publickey"); + auth::Method::OpenSshCertificate { ref cert, .. } => { + user.as_bytes().encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "publickey".encode(&mut self.write)?; self.write.push(0); // This is a probe - self.write.extend_ssh_string(key.name().as_bytes()); - key.push_to(&mut self.write); + debug!("write_auth_request: cert - {:?}", cert.algorithm()); + cert.algorithm() + .to_certificate_type() + .encode(&mut self.write)?; + cert.to_bytes()?.as_slice().encode(&mut self.write)?; + true + } + auth::Method::FuturePublicKey { ref key, hash_alg } => { + user.as_bytes().encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "publickey".encode(&mut self.write)?; + self.write.push(0); // This is a probe + + key.algorithm() + .with_hash_alg(hash_alg) + .as_str() + .encode(&mut self.write)?; + + key.to_bytes()?.as_slice().encode(&mut self.write)?; true } auth::Method::KeyboardInteractive { ref submethods } => { - debug!("Keyboard Iinteractive"); - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"keyboard-interactive"); - self.write.extend_ssh_string(b""); // lang tag is deprecated. Should be empty - self.write.extend_ssh_string(submethods.as_bytes()); + debug!("Keyboard interactive"); + user.as_bytes().encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "keyboard-interactive".encode(&mut self.write)?; + "".encode(&mut self.write)?; // lang tag is deprecated. Should be empty + submethods.as_bytes().encode(&mut self.write)?; true } } - }) + })) } - fn client_make_to_sign( + fn client_make_to_sign( &mut self, user: &str, - key: &Key, + key: &PublicKeyOrCertificate, buffer: &mut CryptoVec, - ) -> usize { + ) -> Result { buffer.clear(); - buffer.extend_ssh_string(self.session_id.as_ref()); + self.session_id.as_ref().encode(buffer)?; let i0 = buffer.len(); buffer.push(msg::USERAUTH_REQUEST); - buffer.extend_ssh_string(user.as_bytes()); - buffer.extend_ssh_string(b"ssh-connection"); - buffer.extend_ssh_string(b"publickey"); - buffer.push(1); - buffer.extend_ssh_string(key.name().as_bytes()); - key.push_to(buffer); - i0 + user.encode(buffer)?; + "ssh-connection".encode(buffer)?; + "publickey".encode(buffer)?; + 1u8.encode(buffer)?; + + match key { + PublicKeyOrCertificate::Certificate(cert) => { + cert.algorithm().to_certificate_type().encode(buffer)?; + cert.to_bytes()?.encode(buffer)?; + } + PublicKeyOrCertificate::PublicKey { key, hash_alg } => { + key.algorithm().with_hash_alg(*hash_alg).encode(buffer)?; + key.to_bytes()?.encode(buffer)?; + } + } + Ok(i0) } fn client_send_signature( @@ -943,10 +984,30 @@ impl Encrypted { buffer: &mut CryptoVec, ) -> Result<(), crate::Error> { match method { - auth::Method::PublicKey { ref key } => { - let i0 = self.client_make_to_sign(user, key.as_ref(), buffer); + auth::Method::PublicKey { key } => { + let i0 = + self.client_make_to_sign(user, &PublicKeyOrCertificate::from(key), buffer)?; + // Extend with self-signature. - key.add_self_signature(buffer)?; + sign_with_hash_alg(key, buffer)?.encode(&mut *buffer)?; + + push_packet!(self.write, { + #[allow(clippy::indexing_slicing)] // length checked + self.write.extend(&buffer[i0..]); + }) + } + auth::Method::OpenSshCertificate { ref key, ref cert } => { + let i0 = self.client_make_to_sign( + user, + &PublicKeyOrCertificate::Certificate(cert.clone()), + buffer, + )?; + + // Extend with self-signature. + signature::Signer::try_sign(key.deref(), buffer)? + .encoded()? + .encode(&mut *buffer)?; + push_packet!(self.write, { #[allow(clippy::indexing_slicing)] // length checked self.write.extend(&buffer[i0..]); @@ -959,12 +1020,11 @@ impl Encrypted { fn client_send_auth_response(&mut self, responses: &[String]) -> Result<(), crate::Error> { push_packet!(self.write, { - self.write.push(msg::USERAUTH_INFO_RESPONSE); - self.write - .push_u32_be(responses.len().try_into().unwrap_or(0)); // number of responses + msg::USERAUTH_INFO_RESPONSE.encode(&mut self.write)?; + (responses.len().try_into().unwrap_or(0) as u32).encode(&mut self.write)?; // number of responses for r in responses { - self.write.extend_ssh_string(r.as_bytes()); // write the reponses + r.encode(&mut self.write)?; // write the reponses } }); Ok(()) diff --git a/russh/src/client/kex.rs b/russh/src/client/kex.rs index afd5ae62..408612b0 100644 --- a/russh/src/client/kex.rs +++ b/russh/src/client/kex.rs @@ -1,77 +1,372 @@ -use log::{debug, trace}; - -use crate::cipher::SealingKey; -use crate::client::Config; -use crate::kex::KEXES; -use crate::negotiation; -use crate::negotiation::Select; -use crate::session::{KexDhDone, KexInit}; -use crate::sshbuffer::SSHBuffer; - -impl KexInit { - pub fn client_parse( - mut self, - config: &Config, - cipher: &mut dyn SealingKey, - buf: &[u8], - write_buffer: &mut SSHBuffer, - ) -> Result { - trace!("client parse {:?} {:?}", buf.len(), buf); - let algo = { - // read algorithms from packet. - debug!("extending {:?}", &self.exchange.server_kex_init[..]); - self.exchange.server_kex_init.extend(buf); - negotiation::Client::read_kex(buf, &config.preferred)? - }; - debug!("algo = {:?}", algo); - debug!("write = {:?}", &write_buffer.buffer[..]); - if !self.sent { - self.client_write(config, cipher, write_buffer)? +use core::fmt; +use std::cell::RefCell; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use bytes::Bytes; +use log::{debug, error, warn}; +use signature::Verifier; +use ssh_encoding::{Decode, Encode}; +use ssh_key::{Mpint, PublicKey, Signature}; + +use super::IncomingSshPacket; +use crate::client::{Config, NewKeys}; +use crate::kex::dh::groups::DhGroup; +use crate::kex::{KexAlgorithm, KexAlgorithmImplementor, KexCause, KexProgress, KEXES}; +use crate::keys::key::parse_public_key; +use crate::negotiation::{Names, Select}; +use crate::session::Exchange; +use crate::sshbuffer::PacketWriter; +use crate::{msg, negotiation, strict_kex_violation, CryptoVec, Error, SshId}; + +thread_local! { + static HASH_BUFFER: RefCell = RefCell::new(CryptoVec::new()); +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum ClientKexState { + Created, + WaitingForGexReply { + names: Names, + kex: KexAlgorithm, + }, + WaitingForDhReply { + // both KexInit and DH init sent + names: Names, + kex: KexAlgorithm, + }, + WaitingForNewKeys { + server_host_key: PublicKey, + newkeys: NewKeys, + }, +} + +pub(crate) struct ClientKex { + exchange: Exchange, + cause: KexCause, + state: ClientKexState, + config: Arc, +} + +impl Debug for ClientKex { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut s = f.debug_struct("ClientKex"); + s.field("cause", &self.cause); + match self.state { + ClientKexState::Created => { + s.field("state", &"created"); + } + ClientKexState::WaitingForGexReply { .. } => { + s.field("state", &"waiting for GEX response"); + } + ClientKexState::WaitingForDhReply { .. } => { + s.field("state", &"waiting for DH response"); + } + ClientKexState::WaitingForNewKeys { .. } => { + s.field("state", &"waiting for NEWKEYS"); + } } + s.finish() + } +} - // This function is called from the public API. - // - // In order to simplify the public API, we reuse the - // self.exchange.client_kex buffer to send an extra packet, - // then truncate that buffer. Without that, we would need an - // extra buffer. - let i0 = self.exchange.client_kex_init.len(); - debug!("i0 = {:?}", i0); - - let mut kex = KEXES - .get(&algo.kex) - .ok_or(crate::Error::UnknownAlgo)? - .make(); - - kex.client_dh( - &mut self.exchange.client_ephemeral, - &mut self.exchange.client_kex_init, - )?; - - #[allow(clippy::indexing_slicing)] // length checked - cipher.write(&self.exchange.client_kex_init[i0..], write_buffer); - self.exchange.client_kex_init.resize(i0); - - debug!("moving to kexdhdone, exchange = {:?}", self.exchange); - Ok(KexDhDone { - exchange: self.exchange, - names: algo, - kex, - key: 0, - session_id: self.session_id, - }) +impl ClientKex { + pub fn new( + config: Arc, + client_sshid: &SshId, + server_sshid: &[u8], + cause: KexCause, + ) -> Self { + let exchange = Exchange::new(client_sshid.as_kex_hash_bytes(), server_sshid); + Self { + config, + exchange, + cause, + state: ClientKexState::Created, + } } - pub fn client_write( - &mut self, - config: &Config, - cipher: &mut dyn SealingKey, - write_buffer: &mut SSHBuffer, - ) -> Result<(), crate::Error> { - self.exchange.client_kex_init.clear(); - negotiation::write_kex(&config.preferred, &mut self.exchange.client_kex_init, false)?; - self.sent = true; - cipher.write(&self.exchange.client_kex_init, write_buffer); + pub fn kexinit(&mut self, output: &mut PacketWriter) -> Result<(), Error> { + self.exchange.client_kex_init = + negotiation::write_kex(&self.config.preferred, output, None)?; + Ok(()) } + + pub fn step( + mut self, + input: Option<&mut IncomingSshPacket>, + output: &mut PacketWriter, + ) -> Result, Error> { + match self.state { + ClientKexState::Created => { + // At this point we expect to read the KEXINIT from the other side + + let Some(input) = input else { + return Err(Error::KexInit); + }; + if input.buffer.first() != Some(&msg::KEXINIT) { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::KexInit); + } + + let names = { + // read algorithms from packet. + self.exchange.server_kex_init.extend(&input.buffer); + negotiation::Client::read_kex(&input.buffer, &self.config.preferred, None)? + }; + debug!("negotiated algorithms: {names:?}"); + + // seqno has already been incremented after read() + if self.cause.is_strict_kex(&names) && !self.cause.is_rekey() && input.seqn.0 != 1 { + return Err(strict_kex_violation( + msg::KEXINIT, + input.seqn.0 as usize - 1, + )); + } + + let mut kex = KEXES.get(&names.kex).ok_or(Error::UnknownAlgo)?.make(); + + if kex.skip_exchange() { + // Non-standard no-kex exchange + let newkeys = compute_keys( + CryptoVec::new(), + kex, + names.clone(), + self.exchange.clone(), + self.cause.session_id(), + )?; + + output.packet(|w| { + msg::NEWKEYS.encode(w)?; + Ok(()) + })?; + + return Ok(KexProgress::Done { + newkeys, + server_host_key: None, + }); + } + + if kex.is_dh_gex() { + output.packet(|w| { + kex.client_dh_gex_init(&self.config.gex, w)?; + Ok(()) + })?; + + self.state = ClientKexState::WaitingForGexReply { names, kex }; + } else { + output.packet(|w| { + kex.client_dh(&mut self.exchange.client_ephemeral, w)?; + Ok(()) + })?; + + self.state = ClientKexState::WaitingForDhReply { names, kex }; + } + + Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn: false, + }) + } + ClientKexState::WaitingForGexReply { names, mut kex } => { + let Some(input) = input else { + return Err(Error::KexInit); + }; + + if input.buffer.first() != Some(&msg::KEX_DH_GEX_GROUP) { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::KexInit); + } + + #[allow(clippy::indexing_slicing)] // length checked + let mut r = &input.buffer[1..]; + + let prime = Mpint::decode(&mut r)?; + let gen = Mpint::decode(&mut r)?; + debug!("received gex group: prime={}, gen={}", prime, gen); + + let group = DhGroup { + prime: prime.as_bytes().to_vec().into(), + generator: gen.as_bytes().to_vec().into(), + }; + + if group.bit_size() < self.config.gex.min_group_size + || group.bit_size() > self.config.gex.max_group_size + { + warn!( + "DH prime size ({} bits) not within requested range", + group.bit_size() + ); + return Err(Error::KexInit); + } + + let exchange = &mut self.exchange; + exchange.gex = Some((self.config.gex.clone(), group.clone())); + kex.dh_gex_set_group(group)?; + output.packet(|w| { + kex.client_dh(&mut exchange.client_ephemeral, w)?; + Ok(()) + })?; + self.state = ClientKexState::WaitingForDhReply { names, kex }; + + Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn: false, + }) + } + ClientKexState::WaitingForDhReply { mut names, mut kex } => { + // At this point, we've sent ECDH_INTI and + // are waiting for the ECDH_REPLY from the server. + + let Some(input) = input else { + return Err(Error::KexInit); + }; + + if names.ignore_guessed { + // Ignore the next packet if (1) it follows and (2) it's not the correct guess. + debug!("ignoring guessed kex"); + names.ignore_guessed = false; + self.state = ClientKexState::WaitingForDhReply { names, kex }; + return Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn: false, + }); + } + + if input.buffer.first() + != Some(match kex.is_dh_gex() { + true => &msg::KEX_DH_GEX_REPLY, + false => &msg::KEX_ECDH_REPLY, + }) + { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::KexInit); + } + + #[allow(clippy::indexing_slicing)] // length checked + let r = &mut &input.buffer[1..]; + + let server_host_key = Bytes::decode(r)?; // server public key. + let server_host_key = parse_public_key(&server_host_key)?; + debug!( + "received server host key: {:?}", + server_host_key.to_openssh() + ); + + let server_ephemeral = Bytes::decode(r)?; + self.exchange.server_ephemeral.extend(&server_ephemeral); + kex.compute_shared_secret(&self.exchange.server_ephemeral)?; + + let mut pubkey_vec = CryptoVec::new(); + server_host_key.to_bytes()?.encode(&mut pubkey_vec)?; + + let exchange = &self.exchange; + let hash = HASH_BUFFER.with({ + |buffer| { + let mut buffer = buffer.borrow_mut(); + buffer.clear(); + kex.compute_exchange_hash(&pubkey_vec, exchange, &mut buffer) + } + })?; + + let signature = Bytes::decode(r)?; + let signature = Signature::decode(&mut &signature[..])?; + + if let Err(e) = Verifier::verify(&server_host_key, hash.as_ref(), &signature) { + debug!("wrong server sig: {e:?}"); + return Err(Error::WrongServerSig); + } + + let newkeys = compute_keys( + hash, + kex, + names.clone(), + self.exchange.clone(), + self.cause.session_id(), + )?; + + output.packet(|w| { + msg::NEWKEYS.encode(w)?; + Ok(()) + })?; + + let reset_seqn = newkeys.names.strict_kex; + + self.state = ClientKexState::WaitingForNewKeys { + server_host_key, + newkeys, + }; + + Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn, + }) + } + ClientKexState::WaitingForNewKeys { + server_host_key, + newkeys, + } => { + // At this point the exchange is complete + // and we're waiting for a KEWKEYS packet + let Some(input) = input else { + return Err(Error::KexInit); + }; + + if input.buffer.first() != Some(&msg::NEWKEYS) { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::Kex); + } + + Ok(KexProgress::Done { + newkeys, + server_host_key: Some(server_host_key), + }) + } + } + } +} + +fn compute_keys( + hash: CryptoVec, + kex: KexAlgorithm, + names: Names, + exchange: Exchange, + session_id: Option<&CryptoVec>, +) -> Result { + let session_id = if let Some(session_id) = session_id { + session_id + } else { + &hash + }; + // Now computing keys. + let c = kex.compute_keys( + session_id, + &hash, + names.cipher, + names.server_mac, + names.client_mac, + false, + )?; + Ok(NewKeys { + exchange, + names, + kex, + key: 0, + cipher: c, + session_id: session_id.clone(), + }) } diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 5cd442bb..4586d971 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -26,104 +26,78 @@ //! The [Session](client::Session) is passed to the [Handler](client::Handler) //! when the client receives data. //! -//! ```no_run -//! use async_trait::async_trait; -//! use std::sync::Arc; -//! use russh::*; -//! use russh::server::{Auth, Session}; -//! use russh_keys::*; -//! use futures::Future; -//! use std::io::Read; +//! Check out the following examples: //! -//! struct Client { -//! } -//! -//! #[async_trait] -//! impl client::Handler for Client { -//! type Error = anyhow::Error; -//! -//! async fn check_server_key(self, server_public_key: &key::PublicKey) -> Result<(Self, bool), Self::Error> { -//! println!("check_server_key: {:?}", server_public_key); -//! Ok((self, true)) -//! } -//! -//! async fn data(self, channel: ChannelId, data: &[u8], session: client::Session) -> Result<(Self, client::Session), Self::Error> { -//! println!("data on channel {:?}: {:?}", channel, std::str::from_utf8(data)); -//! Ok((self, session)) -//! } -//! } -//! -//! #[tokio::main] -//! async fn main() { -//! let config = russh::client::Config::default(); -//! let config = Arc::new(config); -//! let sh = Client{}; -//! -//! let key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); -//! let mut agent = russh_keys::agent::client::AgentClient::connect_env().await.unwrap(); -//! agent.add_identity(&key, &[]).await.unwrap(); -//! let mut session = russh::client::connect(config, ("127.0.0.1", 22), sh).await.unwrap(); -//! if session.authenticate_future(std::env::var("USER").unwrap_or("user".to_owned()), key.clone_public_key().unwrap(), agent).await.1.unwrap() { -//! let mut channel = session.channel_open_session().await.unwrap(); -//! channel.data(&b"Hello, world!"[..]).await.unwrap(); -//! if let Some(msg) = channel.wait().await { -//! println!("{:?}", msg) -//! } -//! } -//! } -//! ``` +//! * [Client that connects to a server, runs a command and prints its output](https://github.com/warp-tech/russh/blob/main/russh/examples/client_exec_simple.rs) +//! * [Client that connects to a server, runs a command in a PTY and provides interactive input/output](https://github.com/warp-tech/russh/blob/main/russh/examples/client_exec_interactive.rs) +//! * [SFTP client (with `russh-sftp`)](https://github.com/warp-tech/russh/blob/main/russh/examples/sftp_client.rs) //! //! [Session]: client::Session -use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::convert::TryInto; +use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; +#[cfg(not(target_arch = "wasm32"))] +use std::time::Duration; -use async_trait::async_trait; use futures::task::{Context, Poll}; use futures::Future; -use log::{debug, error, info, trace}; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Reader; -#[cfg(feature = "openssl")] -use russh_keys::key::SignatureHash; -use russh_keys::key::{self, parse_public_key, PublicKey}; -use tokio; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpStream, ToSocketAddrs}; +use kex::ClientKex; +use log::{debug, error, trace}; +use russh_util::time::Instant; +use ssh_encoding::Decode; +use ssh_key::{Algorithm, Certificate, HashAlg, PrivateKey, PublicKey}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::pin; use tokio::sync::mpsc::{ channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, }; +use tokio::sync::oneshot; -use crate::channels::{Channel, ChannelMsg}; -use crate::cipher::{self, clear, CipherPair, OpeningKey}; -use crate::key::PubKey; -use crate::session::{CommonSession, EncryptedState, Exchange, Kex, KexDhDone, KexInit, NewKeys}; +pub use crate::auth::AuthResult; +use crate::channels::{ + Channel, ChannelMsg, ChannelReadHalf, ChannelRef, ChannelWriteHalf, WindowSizeRef, +}; +use crate::cipher::{self, clear, OpeningKey}; +use crate::kex::{KexCause, KexProgress, SessionKexState}; +use crate::keys::PrivateKeyWithHashAlg; +use crate::msg::{is_kex_msg, validate_server_msg_strict_kex}; +use crate::session::{CommonSession, EncryptedState, GlobalRequestResponse, NewKeys}; use crate::ssh_read::SshRead; -use crate::sshbuffer::{SSHBuffer, SshId}; -use crate::{auth, msg, negotiation, ChannelId, ChannelOpenFailure, Disconnect, Limits, Sig}; +use crate::sshbuffer::{IncomingSshPacket, PacketWriter, SSHBuffer, SshId}; +use crate::{ + auth, map_err, msg, negotiation, ChannelId, ChannelOpenFailure, CryptoVec, Disconnect, Error, + Limits, MethodSet, Sig, +}; mod encrypted; mod kex; mod session; +#[cfg(test)] +mod test; + /// Actual client session's state. /// /// It is in charge of multiplexing and keeping track of various channels /// that may get opened and closed during the lifetime of an SSH session and /// allows sending messages to the server. +#[derive(Debug)] pub struct Session { + kex: SessionKexState, common: CommonSession>, receiver: Receiver, sender: UnboundedSender, - channels: HashMap>, + channels: HashMap, target_window_size: u32, pending_reads: Vec, pending_len: u32, inbound_channel_sender: Sender, inbound_channel_receiver: Receiver, + open_global_requests: VecDeque, + server_sig_algs: Option>, } impl Drop for Session { @@ -136,10 +110,13 @@ impl Drop for Session { #[allow(clippy::large_enum_variant)] enum Reply { AuthSuccess, - AuthFailure, + AuthFailure { + proceed_with_methods: MethodSet, + partial_success: bool, + }, ChannelOpenFailure, SignRequest { - key: key::PublicKey, + key: ssh_key::PublicKey, data: CryptoVec, }, AuthInfoRequest { @@ -162,34 +139,46 @@ pub enum Msg { data: CryptoVec, }, ChannelOpenSession { - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenX11 { originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectTcpIp { host_to_connect: String, port_to_connect: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectStreamLocal { socket_path: String, - sender: UnboundedSender, + channel_ref: ChannelRef, }, TcpIpForward { - want_reply: bool, + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>>, address: String, port: u32, }, CancelTcpIpForward { - want_reply: bool, + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, address: String, port: u32, }, + StreamLocalForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, + socket_path: String, + }, + CancelStreamLocalForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, + socket_path: String, + }, Close { id: ChannelId, }, @@ -199,6 +188,21 @@ pub enum Msg { language_tag: String, }, Channel(ChannelId, ChannelMsg), + Rekey, + AwaitExtensionInfo { + extension_name: String, + reply_channel: oneshot::Sender<()>, + }, + GetServerSigAlgs { + reply_channel: oneshot::Sender>>, + }, + /// Send a keepalive packet to the remote + Keepalive { + want_reply: bool, + }, + NoMoreSessions { + want_reply: bool, + }, } impl From<(ChannelId, ChannelMsg)> for Msg { @@ -210,7 +214,13 @@ impl From<(ChannelId, ChannelMsg)> for Msg { #[derive(Debug)] pub enum KeyboardInteractiveAuthResponse { Success, - Failure, + Failure { + /// The server suggests to proceed with these auth methods + remaining_methods: MethodSet, + /// The server says that though auth method has been accepted, + /// further authentication is required + partial_success: bool, + }, InfoRequest { name: String, instructions: String, @@ -224,12 +234,26 @@ pub struct Prompt { pub echo: bool, } +#[derive(Debug)] +pub struct RemoteDisconnectInfo { + pub reason_code: crate::Disconnect, + pub message: String, + pub lang_tag: String, +} + +#[derive(Debug)] +pub enum DisconnectReason + Send> { + ReceivedDisconnect(RemoteDisconnectInfo), + Error(E), +} + /// Handle to a session, used to send messages to a client outside of /// the request/response cycle. pub struct Handle { sender: Sender, receiver: UnboundedReceiver, - join: tokio::task::JoinHandle>, + join: russh_util::runtime::JoinHandle>, + channel_buffer_size: usize, } impl Drop for Handle { @@ -248,7 +272,7 @@ impl Handle { pub async fn authenticate_none>( &mut self, user: U, - ) -> Result { + ) -> Result { let user = user.into(); self.sender .send(Msg::Authenticate { @@ -265,7 +289,7 @@ impl Handle { &mut self, user: U, password: P, - ) -> Result { + ) -> Result { let user = user.into(); self.sender .send(Msg::Authenticate { @@ -281,7 +305,7 @@ impl Handle { /// Initiate Keyboard-Interactive based SSH authentication. /// - /// * `submethods` - Hnts to the server the preferred methods to be used for authentication + /// * `submethods` - Hints to the server the preferred methods to be used for authentication pub async fn authenticate_keyboard_interactive_start< U: Into, S: Into>, @@ -307,7 +331,7 @@ impl Handle { /// complete Keyboard-Interactive based SSH authentication. /// /// * `responses` - The responses to each prompt. The number of responses must match the number - /// of prompts. If a prompt has an empty string, then the response should be an empty string. + /// of prompts. If a prompt has an empty string, then the response should be an empty string. pub async fn authenticate_keyboard_interactive_respond( &mut self, responses: Vec, @@ -325,7 +349,15 @@ impl Handle { loop { match self.receiver.recv().await { Some(Reply::AuthSuccess) => return Ok(KeyboardInteractiveAuthResponse::Success), - Some(Reply::AuthFailure) => return Ok(KeyboardInteractiveAuthResponse::Failure), + Some(Reply::AuthFailure { + proceed_with_methods: remaining_methods, + partial_success, + }) => { + return Ok(KeyboardInteractiveAuthResponse::Failure { + remaining_methods, + partial_success, + }) + } Some(Reply::AuthInfoRequest { name, instructions, @@ -337,28 +369,48 @@ impl Handle { prompts, }); } + None => return Err(crate::Error::RecvError), _ => {} } } } - async fn wait_recv_reply(&mut self) -> Result { + async fn wait_recv_reply(&mut self) -> Result { loop { match self.receiver.recv().await { - Some(Reply::AuthSuccess) => return Ok(true), - Some(Reply::AuthFailure) => return Ok(false), - None => return Ok(false), + Some(Reply::AuthSuccess) => return Ok(AuthResult::Success), + Some(Reply::AuthFailure { + proceed_with_methods: remaining_methods, + partial_success, + }) => { + return Ok(AuthResult::Failure { + remaining_methods, + partial_success, + }) + } + None => { + return Ok(AuthResult::Failure { + remaining_methods: MethodSet::empty(), + partial_success: false, + }) + } _ => {} } } } /// Perform public key-based SSH authentication. + /// + /// For RSA keys, you'll need to decide on which hash algorithm to use. + /// This is the difference between what is also known as + /// `ssh-rsa`, `rsa-sha2-256`, and `rsa-sha2-512` "keys" in OpenSSH. + /// You can use [Handle::best_supported_rsa_hash] to automatically + /// figure out the best hash algorithm for RSA keys. pub async fn authenticate_publickey>( &mut self, user: U, - key: Arc, - ) -> Result { + key: PrivateKeyWithHashAlg, + ) -> Result { let user = user.into(); self.sender .send(Msg::Authenticate { @@ -370,45 +422,75 @@ impl Handle { self.wait_recv_reply().await } + /// Perform public OpenSSH Certificate-based SSH authentication + pub async fn authenticate_openssh_cert>( + &mut self, + user: U, + key: Arc, + cert: Certificate, + ) -> Result { + let user = user.into(); + self.sender + .send(Msg::Authenticate { + user, + method: auth::Method::OpenSshCertificate { key, cert }, + }) + .await + .map_err(|_| crate::Error::SendError)?; + self.wait_recv_reply().await + } + /// Authenticate using a custom method that implements the /// [`Signer`][auth::Signer] trait. Currently, this crate only provides an - /// implementation for an [SSH - /// agent][russh_keys::agent::client::AgentClient]. - pub async fn authenticate_future, S: auth::Signer>( + /// implementation for an [SSH agent][crate::keys::agent::client::AgentClient]. + pub async fn authenticate_publickey_with, S: auth::Signer>( &mut self, user: U, - key: key::PublicKey, - mut future: S, - ) -> (S, Result) { + key: ssh_key::PublicKey, + hash_alg: Option, + signer: &mut S, + ) -> Result { let user = user.into(); if self .sender .send(Msg::Authenticate { user, - method: auth::Method::FuturePublicKey { key }, + method: auth::Method::FuturePublicKey { key, hash_alg }, }) .await .is_err() { - return (future, Err((crate::SendError {}).into())); + return Err((crate::SendError {}).into()); } loop { let reply = self.receiver.recv().await; match reply { - Some(Reply::AuthSuccess) => return (future, Ok(true)), - Some(Reply::AuthFailure) => return (future, Ok(false)), + Some(Reply::AuthSuccess) => return Ok(AuthResult::Success), + Some(Reply::AuthFailure { + proceed_with_methods: remaining_methods, + partial_success, + }) => { + return Ok(AuthResult::Failure { + remaining_methods, + partial_success, + }) + } Some(Reply::SignRequest { key, data }) => { - let (f, data) = future.auth_publickey_sign(&key, data).await; - future = f; + let data = signer.auth_publickey_sign(&key, hash_alg, data).await; let data = match data { Ok(data) => data, - Err(e) => return (future, Err(e)), + Err(e) => return Err(e), }; if self.sender.send(Msg::Signed { data }).await.is_err() { - return (future, Err((crate::SendError {}).into())); + return Err((crate::SendError {}).into()); } } - None => return (future, Ok(false)), + None => { + return Ok(AuthResult::Failure { + remaining_methods: MethodSet::empty(), + partial_success: false, + }) + } _ => {} } } @@ -417,7 +499,8 @@ impl Handle { /// Wait for confirmation that a channel is open async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, + window_size_ref: WindowSizeRef, ) -> Result, crate::Error> { loop { match receiver.recv().await { @@ -426,18 +509,23 @@ impl Handle { max_packet_size, window_size, }) => { + window_size_ref.update(window_size).await; + return Ok(Channel { - id, - sender: self.sender.clone(), - receiver, - max_packet_size, - window_size, + write_half: ChannelWriteHalf { + id, + sender: self.sender.clone(), + max_packet_size, + window_size: window_size_ref, + }, + read_half: ChannelReadHalf { receiver }, }); } Some(ChannelMsg::OpenFailure(reason)) => { return Err(crate::Error::ChannelOpenFailure(reason)); } None => { + debug!("channel confirmation sender was dropped"); return Err(crate::Error::Disconnect); } msg => { @@ -447,18 +535,82 @@ impl Handle { } } + /// See [`Handle::best_supported_rsa_hash`]. + #[cfg(not(target_arch = "wasm32"))] + async fn await_extension_info(&self, extension_name: String) -> Result<(), crate::Error> { + let (sender, receiver) = oneshot::channel(); + self.sender + .send(Msg::AwaitExtensionInfo { + extension_name, + reply_channel: sender, + }) + .await + .map_err(|_| crate::Error::SendError)?; + let _ = tokio::time::timeout(Duration::from_secs(1), receiver).await; + Ok(()) + } + + /// Returns the best RSA hash algorithm supported by the server, + /// as indicated by the `server-sig-algs` extension. + /// If the server does not support the extension, + /// `None` is returned. In this case you may still attempt an authentication + /// with `rsa-sha2-256` or `rsa-sha2-512` and hope for the best. + /// If the server supports the extension, but does not support `rsa-sha2-*`, + /// `Some(None)` is returned. + /// + /// Note that this method will wait for up to 1 second for the server to + /// send the extension info if it hasn't done so yet (except when running under + /// WebAssembly). Unfortunately the timing of the EXT_INFO message cannot be known + /// in advance (RFC 8308). + /// + /// If this method returns `None` once, then for most SSH servers + /// you can assume that it will return `None` every time. + pub async fn best_supported_rsa_hash(&self) -> Result>, Error> { + // Wait for the extension info from the server + #[cfg(not(target_arch = "wasm32"))] + self.await_extension_info("server-sig-algs".into()).await?; + + let (sender, receiver) = oneshot::channel(); + + self.sender + .send(Msg::GetServerSigAlgs { + reply_channel: sender, + }) + .await + .map_err(|_| crate::Error::SendError)?; + + if let Some(ssa) = receiver.await.map_err(|_| Error::Inconsistent)? { + let possible_algs = [ + Some(ssh_key::HashAlg::Sha512), + Some(ssh_key::HashAlg::Sha256), + None, + ]; + for alg in possible_algs.into_iter() { + if ssa.contains(&Algorithm::Rsa { hash: alg }) { + return Ok(Some(alg)); + } + } + } + + Ok(None) + } + /// Request a session channel (the most basic type of /// channel). This function returns `Some(..)` immediately if the /// connection is authenticated, but the channel only becomes /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender - .send(Msg::ChannelOpenSession { sender }) + .send(Msg::ChannelOpenSession { channel_ref }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Request an X11 channel, on which the X11 protocol may be tunneled. @@ -467,16 +619,20 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenX11 { originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Open a TCP/IP forwarding channel. This is usually done when a @@ -494,68 +650,143 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectTcpIp { host_to_connect: host_to_connect.into(), port_to_connect, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_direct_streamlocal>( &self, socket_path: S, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectStreamLocal { socket_path: socket_path.into(), - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } + /// Requests the server to open a TCP/IP forward channel + /// + /// If port == 0 the server will choose a port that will be returned, returns 0 otherwise pub async fn tcpip_forward>( &mut self, address: A, port: u32, - ) -> Result { + ) -> Result { + let (reply_send, reply_recv) = oneshot::channel(); self.sender .send(Msg::TcpIpForward { - want_reply: true, + reply_channel: Some(reply_send), address: address.into(), port, }) .await .map_err(|_| crate::Error::SendError)?; - if port == 0 { - self.wait_recv_reply().await?; + + match reply_recv.await { + Ok(Some(port)) => Ok(port), + Ok(None) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive TcpIpForward result: {e:?}"); + Err(crate::Error::Disconnect) + } } - Ok(true) } + // Requests the server to close a TCP/IP forward channel pub async fn cancel_tcpip_forward>( &self, address: A, port: u32, - ) -> Result { + ) -> Result<(), crate::Error> { + let (reply_send, reply_recv) = oneshot::channel(); self.sender .send(Msg::CancelTcpIpForward { - want_reply: true, + reply_channel: Some(reply_send), address: address.into(), port, }) .await .map_err(|_| crate::Error::SendError)?; - Ok(true) + + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive CancelTcpIpForward result: {e:?}"); + Err(crate::Error::Disconnect) + } + } + } + + // Requests the server to open a UDS forward channel + pub async fn streamlocal_forward>( + &mut self, + socket_path: A, + ) -> Result<(), crate::Error> { + let (reply_send, reply_recv) = oneshot::channel(); + self.sender + .send(Msg::StreamLocalForward { + reply_channel: Some(reply_send), + socket_path: socket_path.into(), + }) + .await + .map_err(|_| crate::Error::SendError)?; + + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive StreamLocalForward result: {e:?}"); + Err(crate::Error::Disconnect) + } + } + } + + // Requests the server to close a UDS forward channel + pub async fn cancel_streamlocal_forward>( + &self, + socket_path: A, + ) -> Result<(), crate::Error> { + let (reply_send, reply_recv) = oneshot::channel(); + self.sender + .send(Msg::CancelStreamLocalForward { + reply_channel: Some(reply_send), + socket_path: socket_path.into(), + }) + .await + .map_err(|_| crate::Error::SendError)?; + + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(crate::Error::RequestDenied), + Err(e) => { + error!("Unable to receive CancelStreamLocalForward result: {e:?}"); + Err(crate::Error::Disconnect) + } + } } /// Sends a disconnect message. @@ -589,6 +820,32 @@ impl Handle { _ => unreachable!(), }) } + + /// Asynchronously perform a session re-key at the next opportunity + pub async fn rekey_soon(&self) -> Result<(), Error> { + self.sender + .send(Msg::Rekey) + .await + .map_err(|_| Error::SendError)?; + + Ok(()) + } + + /// Send a keepalive package to the remote peer. + pub async fn send_keepalive(&self, want_reply: bool) -> Result<(), Error> { + self.sender + .send(Msg::Keepalive { want_reply }) + .await + .map_err(|_| Error::SendError) + } + + /// Send a no-more-sessions request to the remote peer. + pub async fn no_more_sessions(&self, want_reply: bool) -> Result<(), Error> { + self.sender + .send(Msg::NoMoreSessions { want_reply }) + .await + .map_err(|_| Error::SendError) + } } impl Future for Handle { @@ -612,14 +869,13 @@ impl Future for Handle { /// commands, etc. The future will resolve to an error if the connection fails. /// This function creates a connection to the `addr` specified using a /// [`tokio::net::TcpStream`] and then calls [`connect_stream`] under the hood. -pub async fn connect( +#[cfg(not(target_arch = "wasm32"))] +pub async fn connect( config: Arc, addrs: A, handler: H, ) -> Result, H::Error> { - let socket = TcpStream::connect(addrs) - .await - .map_err(crate::Error::from)?; + let socket = map_err!(tokio::net::TcpStream::connect(addrs).await)?; connect_stream(config, socket, handler).await } @@ -638,15 +894,16 @@ where { // Writing SSH id. let mut write_buffer = SSHBuffer::new(); + + debug!("ssh id = {:?}", config.as_ref().client_id); + write_buffer.send_ssh_id(&config.as_ref().client_id); - stream - .write_all(&write_buffer.buffer) - .await - .map_err(crate::Error::from)?; + map_err!(stream.write_all(&write_buffer.buffer).await)?; // Reading SSH id and allocating a session if correct. let mut stream = SshRead::new(stream); let sshid = stream.read_ssh_id().await?; + let (handle_sender, session_receiver) = channel(10); let (session_sender, handle_receiver) = unbounded_channel(); if config.maximum_packet_size > 65535 { @@ -655,32 +912,36 @@ where config.maximum_packet_size ); } + let channel_buffer_size = config.channel_buffer_size; let mut session = Session::new( config.window_size, CommonSession { - write_buffer, - kex: None, + packet_writer: PacketWriter::clear(), auth_user: String::new(), auth_attempts: 0, auth_method: None, // Client only. - cipher: CipherPair { - local_to_remote: Box::new(clear::Key), - remote_to_local: Box::new(clear::Key), - }, + remote_to_local: Box::new(clear::Key), encrypted: None, config, wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + strict_kex: false, + alive_timeouts: 0, + received_data: false, + remote_sshid: sshid.into(), }, session_receiver, session_sender, ); - session.read_ssh_id(sshid)?; - let (encrypted_signal, encrypted_recv) = tokio::sync::oneshot::channel(); - let join = tokio::spawn(session.run(stream, handler, Some(encrypted_signal))); - - if encrypted_recv.await.is_err() { + session.begin_rekey()?; + let (kex_done_signal, kex_done_signal_rx) = oneshot::channel(); + let join = russh_util::runtime::spawn(session.run(stream, handler, Some(kex_done_signal))); + + if let Err(err) = kex_done_signal_rx.await { + // kex_done_signal Sender is dropped when the session + // fails before a succesful key exchange + debug!("kex_done_signal sender was dropped {err:?}"); join.await.map_err(crate::Error::Join)??; return Err(H::Error::from(crate::Error::Disconnect)); } @@ -689,6 +950,7 @@ where sender: handle_sender, receiver: handle_receiver, join, + channel_buffer_size, }) } @@ -703,6 +965,26 @@ async fn start_reading( } impl Session { + fn maybe_decompress(&mut self, buffer: &SSHBuffer) -> Result { + if let Some(ref mut enc) = self.common.encrypted { + let mut decomp = CryptoVec::new(); + Ok(IncomingSshPacket { + #[allow(clippy::indexing_slicing)] // length checked + buffer: enc.decompress.decompress( + &buffer.buffer[5..], + &mut decomp, + )?.into(), + seqn: buffer.seqn, + }) + } else { + Ok(IncomingSshPacket { + #[allow(clippy::indexing_slicing)] // length checked + buffer: buffer.buffer[5..].into(), + seqn: buffer.seqn, + }) + } + } + fn new( target_window_size: u32, common: CommonSession>, @@ -714,87 +996,138 @@ impl Session { common, receiver, sender, + kex: SessionKexState::Idle, target_window_size, inbound_channel_sender, inbound_channel_receiver, channels: HashMap::new(), pending_reads: Vec::new(), pending_len: 0, + open_global_requests: VecDeque::new(), + server_sig_algs: None, } } async fn run( mut self, - mut stream: SshRead, + stream: SshRead, mut handler: H, - mut encrypted_signal: Option>, + mut kex_done_signal: Option>, ) -> Result<(), H::Error> { - self.flush()?; - if !self.common.write_buffer.buffer.is_empty() { - debug!("writing {:?} bytes", self.common.write_buffer.buffer.len()); - stream - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; - stream.flush().await.map_err(crate::Error::from)?; + let (stream_read, mut stream_write) = stream.split(); + let result = self + .run_inner( + stream_read, + &mut stream_write, + &mut handler, + &mut kex_done_signal, + ) + .await; + trace!("disconnected"); + self.receiver.close(); + self.inbound_channel_receiver.close(); + map_err!(stream_write.shutdown().await)?; + match result { + Ok(v) => { + handler + .disconnected(DisconnectReason::ReceivedDisconnect(v)) + .await?; + Ok(()) + } + Err(e) => { + if kex_done_signal.is_some() { + // The kex signal has not been consumed yet, + // so we can send return the concrete error to be propagated + // into the JoinHandle and returned from `connect_stream` + Err(e) + } else { + // The kex signal has been consumed, so no one is + // awaiting the result of this coroutine + // We're better off passing the error into the Handler + debug!("disconnected {e:?}"); + handler.disconnected(DisconnectReason::Error(e)).await?; + Err(H::Error::from(crate::Error::Disconnect)) + } + } } - self.common.write_buffer.buffer.clear(); - let mut decomp = CryptoVec::new(); + } + + async fn run_inner( + &mut self, + stream_read: SshRead>, + stream_write: &mut WriteHalf, + handler: &mut H, + kex_done_signal: &mut Option>, + ) -> Result { + let mut result: Result = Err(Error::Disconnect.into()); + self.flush()?; + + map_err!(self.common.packet_writer.flush_into(stream_write).await)?; - let (stream_read, mut stream_write) = stream.split(); let buffer = SSHBuffer::new(); // Allow handing out references to the cipher let mut opening_cipher = Box::new(clear::Key) as Box; - std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local); + + let keepalive_timer = + crate::future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); + pin!(keepalive_timer); + + let inactivity_timer = + crate::future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); + pin!(inactivity_timer); let reading = start_reading(stream_read, buffer, opening_cipher); pin!(reading); #[allow(clippy::panic)] // false positive in select! macro while !self.common.disconnected { + self.common.received_data = false; + let mut sent_keepalive = false; tokio::select! { r = &mut reading => { - let (stream_read, buffer, mut opening_cipher) = match r { + let (stream_read, mut buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), Err(e) => return Err(e.into()) }; - std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local); if buffer.buffer.len() < 5 { break } - let buf = if let Some(ref mut enc) = self.common.encrypted { + let mut pkt = self.maybe_decompress(&buffer)?; + if !pkt.buffer.is_empty() { #[allow(clippy::indexing_slicing)] // length checked - if let Ok(buf) = enc.decompress.decompress( - &buffer.buffer[5..], - &mut decomp, - ) { - buf + if pkt.buffer[0] == crate::msg::DISCONNECT { + debug!("received disconnect"); + result = self.process_disconnect(&pkt).map_err(H::Error::from); } else { - break - } - } else { - #[allow(clippy::indexing_slicing)] // length checked - &buffer.buffer[5..] - }; - if !buf.is_empty() { - #[allow(clippy::indexing_slicing)] // length checked - if buf[0] == crate::msg::DISCONNECT { - break; - } else if buf[0] > 4 { - let (h, s) = reply(self, handler, &mut encrypted_signal, buf).await?; - handler = h; - self = s; + self.common.received_data = true; + reply(self, handler, kex_done_signal, &mut pkt).await?; + buffer.seqn = pkt.seqn; // TODO reply changes seqn internall, find cleaner way } } - std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local); reading.set(start_reading(stream_read, buffer, opening_cipher)); } - msg = self.receiver.recv(), if !self.is_rekeying() => { + () = &mut keepalive_timer => { + if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { + debug!("Timeout, server not responding to keepalives"); + return Err(crate::Error::KeepaliveTimeout.into()); + } + self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); + self.send_keepalive(true)?; + sent_keepalive = true; + } + () = &mut inactivity_timer => { + debug!("timeout"); + return Err(crate::Error::InactivityTimeout.into()); + } + msg = self.receiver.recv(), if !self.kex.active() => { match msg { Some(msg) => self.handle_msg(msg)?, None => { @@ -804,80 +1137,113 @@ impl Session { }; // eagerly take all outgoing messages so writes are batched - while !self.is_rekeying() { + while !self.kex.active() { match self.receiver.try_recv() { Ok(next) => self.handle_msg(next)?, Err(_) => break } } } - msg = self.inbound_channel_receiver.recv(), if !self.is_rekeying() => { + msg = self.inbound_channel_receiver.recv(), if !self.kex.active() => { match msg { Some(msg) => self.handle_msg(msg)?, None => (), } // eagerly take all outgoing messages so writes are batched - while !self.is_rekeying() { + while !self.kex.active() { match self.inbound_channel_receiver.try_recv() { Ok(next) => self.handle_msg(next)?, Err(_) => break } } } - } + }; + self.flush()?; - if !self.common.write_buffer.buffer.is_empty() { - trace!( - "writing to stream: {:?} bytes", - self.common.write_buffer.buffer.len() - ); - stream_write - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; - stream_write.flush().await.map_err(crate::Error::from)?; - } - self.common.write_buffer.buffer.clear(); + map_err!(self.common.packet_writer.flush_into(stream_write).await)?; + if let Some(ref mut enc) = self.common.encrypted { if let EncryptedState::InitCompression = enc.state { - enc.client_compression.init_compress(&mut enc.compress); + enc.client_compression + .init_compress(self.common.packet_writer.compress()); enc.state = EncryptedState::Authenticated; } } + + if self.common.received_data { + // Reset the number of failed keepalive attempts. We don't + // bother detecting keepalive response messages specifically + // (OpenSSH_9.6p1 responds with REQUEST_FAILURE aka 82). Instead + // we assume that the server is still alive if we receive any + // data from it. + self.common.alive_timeouts = 0; + } + if self.common.received_data || sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + keepalive_timer.as_mut().as_pin_mut(), + self.common.config.keepalive_interval, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } + if !sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + inactivity_timer.as_mut().as_pin_mut(), + self.common.config.inactivity_timeout, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } } - debug!("disconnected"); - if self.common.disconnected { - stream_write.shutdown().await.map_err(crate::Error::from)?; - } - Ok(()) + + result + } + + fn process_disconnect( + &mut self, + pkt: &IncomingSshPacket, + ) -> Result { + let mut r = &pkt.buffer[..]; + u8::decode(&mut r)?; // skip message type + self.common.disconnected = true; + + let reason_code = u32::decode(&mut r)?.try_into()?; + let message = String::decode(&mut r)?; + let lang_tag = String::decode(&mut r)?; + + Ok(RemoteDisconnectInfo { + reason_code, + message, + lang_tag, + }) } fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> { match msg { Msg::Authenticate { user, method } => { - self.write_auth_request_if_needed(&user, method); + self.write_auth_request_if_needed(&user, method)?; } Msg::Signed { .. } => {} Msg::AuthInfoResponse { .. } => {} - Msg::ChannelOpenSession { sender } => { + Msg::ChannelOpenSession { channel_ref } => { let id = self.channel_open_session()?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenX11 { originator_address, originator_port, - sender, + channel_ref, } => { let id = self.channel_open_x11(&originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, - sender, + channel_ref, } => { let id = self.channel_open_direct_tcpip( &host_to_connect, @@ -885,36 +1251,44 @@ impl Session { &originator_address, originator_port, )?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenDirectStreamLocal { socket_path, - sender, + channel_ref, } => { let id = self.channel_open_direct_streamlocal(&socket_path)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::TcpIpForward { - want_reply, + reply_channel, address, port, - } => self.tcpip_forward(want_reply, &address, port), + } => self.tcpip_forward(reply_channel, &address, port)?, Msg::CancelTcpIpForward { - want_reply, + reply_channel, address, port, - } => self.cancel_tcpip_forward(want_reply, &address, port), + } => self.cancel_tcpip_forward(reply_channel, &address, port)?, + Msg::StreamLocalForward { + reply_channel, + socket_path, + } => self.streamlocal_forward(reply_channel, &socket_path)?, + Msg::CancelStreamLocalForward { + reply_channel, + socket_path, + } => self.cancel_streamlocal_forward(reply_channel, &socket_path)?, Msg::Disconnect { reason, description, language_tag, - } => self.disconnect(reason, &description, &language_tag), - Msg::Channel(id, ChannelMsg::Data { data }) => self.data(id, data), + } => self.disconnect(reason, &description, &language_tag)?, + Msg::Channel(id, ChannelMsg::Data { data }) => self.data(id, data)?, Msg::Channel(id, ChannelMsg::Eof) => { - self.eof(id); + self.eof(id)?; } Msg::Channel(id, ChannelMsg::ExtendedData { data, ext }) => { - self.extended_data(id, ext, data); + self.extended_data(id, ext, data)?; } Msg::Channel( id, @@ -936,7 +1310,7 @@ impl Session { pix_width, pix_height, &terminal_modes, - ), + )?, Msg::Channel( id, ChannelMsg::WindowChange { @@ -945,7 +1319,7 @@ impl Session { pix_width, pix_height, }, - ) => self.window_change(id, col_width, row_height, pix_width, pix_height), + ) => self.window_change(id, col_width, row_height, pix_width, pix_height)?, Msg::Channel( id, ChannelMsg::RequestX11 { @@ -962,7 +1336,7 @@ impl Session { &x11_authentication_protocol, &x11_authentication_cookie, x11_screen_number, - ), + )?, Msg::Channel( id, ChannelMsg::SetEnv { @@ -970,9 +1344,9 @@ impl Session { variable_name, variable_value, }, - ) => self.set_env(id, want_reply, &variable_name, &variable_value), + ) => self.set_env(id, want_reply, &variable_name, &variable_value)?, Msg::Channel(id, ChannelMsg::RequestShell { want_reply }) => { - self.request_shell(want_reply, id) + self.request_shell(want_reply, id)? } Msg::Channel( id, @@ -980,15 +1354,43 @@ impl Session { want_reply, command, }, - ) => self.exec(id, want_reply, &command), - Msg::Channel(id, ChannelMsg::Signal { signal }) => self.signal(id, signal), + ) => self.exec(id, want_reply, &command)?, + Msg::Channel(id, ChannelMsg::Signal { signal }) => self.signal(id, signal)?, Msg::Channel(id, ChannelMsg::RequestSubsystem { want_reply, name }) => { - self.request_subsystem(want_reply, id, &name) + self.request_subsystem(want_reply, id, &name)? } Msg::Channel(id, ChannelMsg::AgentForward { want_reply }) => { - self.agent_forward(id, want_reply) + self.agent_forward(id, want_reply)? + } + Msg::Channel(id, ChannelMsg::Close) => self.close(id)?, + Msg::Rekey => self.initiate_rekey()?, + Msg::AwaitExtensionInfo { + extension_name, + reply_channel, + } => { + if let Some(ref mut enc) = self.common.encrypted { + // Drop if the extension has been seen already + if !enc.received_extensions.contains(&extension_name) { + // There will be no new extension info after authentication + // has succeeded + if !matches!(enc.state, EncryptedState::Authenticated) { + enc.extension_info_awaiters + .entry(extension_name) + .or_insert(vec![]) + .push(reply_channel); + } + } + } + } + Msg::GetServerSigAlgs { reply_channel } => { + let _ = reply_channel.send(self.server_sig_algs.clone()); + } + Msg::Keepalive { want_reply } => { + let _ = self.send_keepalive(want_reply); + } + Msg::NoMoreSessions { want_reply } => { + let _ = self.no_more_sessions(want_reply); } - Msg::Channel(id, ChannelMsg::Close) => self.close(id), msg => { // should be unreachable, since the receiver only gets // messages from methods implemented within russh @@ -998,35 +1400,23 @@ impl Session { Ok(()) } - fn is_rekeying(&self) -> bool { - if let Some(ref enc) = self.common.encrypted { - enc.rekey.is_some() - } else { - true - } - } + fn begin_rekey(&mut self) -> Result<(), crate::Error> { + debug!("beginning re-key"); + let mut kex = ClientKex::new( + self.common.config.clone(), + &self.common.config.client_id, + &self.common.remote_sshid, + match &self.common.encrypted { + None => KexCause::Initial, + Some(enc) => KexCause::Rekey { + strict: self.common.strict_kex, + session_id: enc.session_id.clone(), + }, + }, + ); - fn read_ssh_id(&mut self, sshid: &[u8]) -> Result<(), crate::Error> { - // self.read_buffer.bytes += sshid.bytes_read + 2; - let mut exchange = Exchange::new(); - exchange.server_id.extend(sshid); - // Preparing the response - exchange - .client_id - .extend(self.common.config.client_id.as_kex_hash_bytes()); - let mut kexinit = KexInit { - exchange, - algo: None, - sent: false, - session_id: None, - }; - self.common.write_buffer.buffer.clear(); - kexinit.client_write( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - &mut self.common.write_buffer, - )?; - self.common.kex = Some(Kex::Init(kexinit)); + kex.kexinit(&mut self.common.packet_writer)?; + self.kex = SessionKexState::InProgress(kex); Ok(()) } @@ -1036,183 +1426,124 @@ impl Session { if let Some(ref mut enc) = self.common.encrypted { if enc.flush( &self.common.config.as_ref().limits, - &mut *self.common.cipher.local_to_remote, - &mut self.common.write_buffer, - )? { - info!("Re-exchanging keys"); - if enc.rekey.is_none() { - if let Some(exchange) = std::mem::replace(&mut enc.exchange, None) { - let mut kexinit = KexInit::initiate_rekey(exchange, &enc.session_id); - kexinit.client_write( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - &mut self.common.write_buffer, - )?; - enc.rekey = Some(Kex::Init(kexinit)) - } - } + &mut self.common.packet_writer, + )? && !self.kex.active() + { + self.begin_rekey()?; } } Ok(()) } - /// Send a `ChannelMsg` from the background handler to the client. - pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool { - if let Some(chan) = self.channels.get(&channel) { - chan.send(msg).unwrap_or(()); - true - } else { - false + /// Immediately trigger a session re-key after flushing all pending packets + pub fn initiate_rekey(&mut self) -> Result<(), Error> { + if let Some(ref mut enc) = self.common.encrypted { + enc.rekey_wanted = true; + self.flush()? } + Ok(()) } } -thread_local! { - static HASH_BUFFER: RefCell = RefCell::new(CryptoVec::new()); -} +async fn reply( + session: &mut Session, + handler: &mut H, + kex_done_signal: &mut Option>, + pkt: &mut IncomingSshPacket, +) -> Result<(), H::Error> { + if let Some(message_type) = pkt.buffer.first() { + debug!( + "< msg type {message_type:?}, seqn {:?}, len {}", + pkt.seqn.0, + pkt.buffer.len() + ); + if session.common.strict_kex && session.common.encrypted.is_none() { + let seqno = pkt.seqn.0 - 1; // was incremented after read() + validate_server_msg_strict_kex(*message_type, seqno as usize)?; + } -impl KexDhDone { - async fn server_key_check( - mut self, - rekey: bool, - mut handler: H, - buf: &[u8], - ) -> Result<(NewKeys, H), H::Error> { - let mut reader = buf.reader(1); - let pubkey = reader.read_string().map_err(crate::Error::from)?; // server public key. - let pubkey = parse_public_key( - pubkey, - #[cfg(feature = "openssl")] - SignatureHash::from_rsa_hostkey_algo(self.names.key.0.as_bytes()), - ) - .map_err(crate::Error::from)?; - debug!("server_public_Key: {:?}", pubkey); - if !rekey { - let ret = handler.check_server_key(&pubkey).await?; - handler = ret.0; - let check = ret.1; - if !check { - return Err(crate::Error::UnknownKey.into()); - } + if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) { + return Ok(()); } - HASH_BUFFER.with(|buffer| { - let mut buffer = buffer.borrow_mut(); - buffer.clear(); - let hash = { - let server_ephemeral = reader.read_string().map_err(crate::Error::from)?; - self.exchange.server_ephemeral.extend(server_ephemeral); - let signature = reader.read_string().map_err(crate::Error::from)?; - - self.kex - .compute_shared_secret(&self.exchange.server_ephemeral)?; - debug!("kexdhdone.exchange = {:?}", self.exchange); - - let mut pubkey_vec = CryptoVec::new(); - pubkey.push_to(&mut pubkey_vec); - - let hash = - self.kex - .compute_exchange_hash(&pubkey_vec, &self.exchange, &mut buffer)?; - - debug!("exchange hash: {:?}", hash); - let signature = { - let mut sig_reader = signature.reader(0); - let sig_type = sig_reader.read_string().map_err(crate::Error::from)?; - debug!("sig_type: {:?}", sig_type); - sig_reader.read_string().map_err(crate::Error::from)? - }; - use russh_keys::key::Verify; - debug!("signature: {:?}", signature); - if !pubkey.verify_server_auth(hash.as_ref(), signature) { - debug!("wrong server sig"); - return Err(crate::Error::WrongServerSig.into()); - } - hash - }; - let mut newkeys = self.compute_keys(hash, false)?; - newkeys.sent = true; - Ok((newkeys, handler)) - }) } -} -async fn reply( - mut session: Session, - mut handler: H, - sender: &mut Option>, - buf: &[u8], -) -> Result<(H, Session), H::Error> { - match session.common.kex.take() { - Some(Kex::Init(kexinit)) => { - if kexinit.algo.is_some() - || buf.first() == Some(&msg::KEXINIT) - || session.common.encrypted.is_none() - { - let done = kexinit.client_parse( - session.common.config.as_ref(), - &mut *session.common.cipher.local_to_remote, - buf, - &mut session.common.write_buffer, - )?; + if pkt.buffer.first() == Some(&msg::KEXINIT) && session.kex == SessionKexState::Idle { + // Not currently in a rekey but received KEXINIT + debug!("server has initiated re-key"); + session.begin_rekey()?; + // Kex will consume the packet right away + } + + let is_kex_msg = pkt.buffer.first().cloned().map(is_kex_msg).unwrap_or(false); - if done.kex.skip_exchange() { - session.common.encrypted( - initial_encrypted_state(&session), - done.compute_keys(CryptoVec::new(), false)?, - ); + if is_kex_msg { + if let SessionKexState::InProgress(kex) = session.kex.take() { + let progress = kex.step(Some(pkt), &mut session.common.packet_writer)?; - if let Some(sender) = sender.take() { - sender.send(()).unwrap_or(()); + match progress { + KexProgress::NeedsReply { kex, reset_seqn } => { + debug!("kex impl continues: {kex:?}"); + session.kex = SessionKexState::InProgress(kex); + if reset_seqn { + debug!("kex impl requests seqno reset"); + session.common.reset_seqn(); } - } else { - session.common.kex = Some(Kex::DhDone(done)); } - session.flush()?; - } - Ok((handler, session)) - } - Some(Kex::DhDone(mut kexdhdone)) => { - if kexdhdone.names.ignore_guessed { - kexdhdone.names.ignore_guessed = false; - session.common.kex = Some(Kex::DhDone(kexdhdone)); - Ok((handler, session)) - } else if buf.first() == Some(&msg::KEX_ECDH_REPLY) { - // We've sent ECDH_INIT, waiting for ECDH_REPLY - let (kex, h) = kexdhdone.server_key_check(false, handler, buf).await?; - handler = h; - session.common.kex = Some(Kex::Keys(kex)); - session - .common - .cipher - .local_to_remote - .write(&[msg::NEWKEYS], &mut session.common.write_buffer); - session.flush()?; - Ok((handler, session)) - } else { - error!("Wrong packet received"); - Err(crate::Error::Inconsistent.into()) - } - } - Some(Kex::Keys(newkeys)) => { - debug!("newkeys received"); - if buf.first() != Some(&msg::NEWKEYS) { - return Err(crate::Error::Kex.into()); - } - if let Some(sender) = sender.take() { - sender.send(()).unwrap_or(()); + KexProgress::Done { + server_host_key, + newkeys, + } => { + debug!("kex impl has completed"); + session.common.strict_kex = + session.common.strict_kex || newkeys.names.strict_kex; + + if let Some(ref mut enc) = session.common.encrypted { + // This is a rekey + enc.last_rekey = Instant::now(); + session.common.packet_writer.buffer().bytes = 0; + enc.flush_all_pending()?; + let mut pending = std::mem::take(&mut session.pending_reads); + for p in pending.drain(..) { + session.process_packet(handler, &p).await?; + } + session.pending_reads = pending; + session.pending_len = 0; + session.common.newkeys(newkeys); + } else { + // This is the initial kex + if let Some(server_host_key) = &server_host_key { + let check = handler.check_server_key(server_host_key).await?; + if !check { + return Err(crate::Error::UnknownKey.into()); + } + } + + session + .common + .encrypted(initial_encrypted_state(session), newkeys); + + if let Some(sender) = kex_done_signal.take() { + sender.send(()).unwrap_or(()); + } + } + + session.kex = SessionKexState::Idle; + + if session.common.strict_kex { + pkt.seqn = Wrapping(0); + } + + debug!("kex done"); + } } - session - .common - .encrypted(initial_encrypted_state(&session), newkeys); - // Ok, NEWKEYS received, now encrypted. - Ok((handler, session)) - } - Some(kex) => { - session.common.kex = Some(kex); - Ok((handler, session)) + + session.flush()?; + + return Ok(()); } - None => session.client_read_encrypted(handler, buf).await, } + + session.client_read_encrypted(handler, pkt).await } fn initial_encrypted_state(session: &Session) -> EncryptedState { @@ -1226,6 +1557,74 @@ fn initial_encrypted_state(session: &Session) -> EncryptedState { } } +/// Parameters for dynamic group Diffie-Hellman key exchanges. +#[derive(Debug, Clone)] +pub struct GexParams { + /// Minimum DH group size (in bits) + min_group_size: usize, + /// Preferred DH group size (in bits) + preferred_group_size: usize, + /// Maximum DH group size (in bits) + max_group_size: usize, +} + +impl GexParams { + pub fn new( + min_group_size: usize, + preferred_group_size: usize, + max_group_size: usize, + ) -> Result { + let this = Self { + min_group_size, + preferred_group_size, + max_group_size, + }; + this.validate()?; + Ok(this) + } + + pub(crate) fn validate(&self) -> Result<(), Error> { + if self.min_group_size < 2048 { + return Err(Error::InvalidConfig( + "min_group_size must be at least 2048 bits".into(), + )); + } + if self.preferred_group_size < self.min_group_size { + return Err(Error::InvalidConfig( + "preferred_group_size must be at least as large as min_group_size".into(), + )); + } + if self.max_group_size < self.preferred_group_size { + return Err(Error::InvalidConfig( + "max_group_size must be at least as large as preferred_group_size".into(), + )); + } + Ok(()) + } + + pub fn min_group_size(&self) -> usize { + self.min_group_size + } + + pub fn preferred_group_size(&self) -> usize { + self.preferred_group_size + } + + pub fn max_group_size(&self) -> usize { + self.max_group_size + } +} + +impl Default for GexParams { + fn default() -> GexParams { + GexParams { + min_group_size: 3072, + preferred_group_size: 8192, + max_group_size: 8192, + } + } +} + /// The configuration of clients. #[derive(Debug)] pub struct Config { @@ -1237,12 +1636,20 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) + pub channel_buffer_size: usize, /// Lists of preferred algorithms. pub preferred: negotiation::Preferred, /// Time after which the connection is garbage-collected. pub inactivity_timeout: Option, + /// If nothing is received from the server for this amount of time, send a keepalive message. + pub keepalive_interval: Option, + /// If this many keepalives have been sent without reply, close the connection. + pub keepalive_max: usize, /// Whether to expect and wait for an authentication call. pub anonymous: bool, + /// DH dynamic group exchange parameters. + pub gex: GexParams, } impl Default for Config { @@ -1256,9 +1663,13 @@ impl Default for Config { limits: Limits::default(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, preferred: Default::default(), inactivity_timeout: None, + keepalive_interval: None, + keepalive_max: 3, anonymous: false, + gex: Default::default(), } } } @@ -1266,171 +1677,209 @@ impl Default for Config { /// A client handler. Note that messages can be received from the /// server at any time during a session. /// -/// Note: this is an `async_trait`. Click `[source]` on the right to see actual async function definitions. - -#[async_trait] +/// You must at the very least implement the `check_server_key` fn. +/// The default implementation rejects all keys. +/// +/// Note: this is an async trait. The trait functions return `impl Future`, +/// and you can simply define them as `async fn` instead. +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] pub trait Handler: Sized + Send { - type Error: From + Send; + type Error: From + Send + core::fmt::Debug; /// Called when the server sends us an authentication banner. This /// is usually meant to be shown to the user, see /// [RFC4252](https://tools.ietf.org/html/rfc4252#section-5.4) for /// more details. - /// - /// The returned Boolean is ignored. #[allow(unused_variables)] - async fn auth_banner( - self, + fn auth_banner( + &mut self, banner: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called to check the server's public key. This is a very important /// step to help prevent man-in-the-middle attacks. The default /// implementation rejects all keys. #[allow(unused_variables)] - async fn check_server_key( - self, - server_public_key: &key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, false)) + fn check_server_key( + &mut self, + server_public_key: &ssh_key::PublicKey, + ) -> impl Future> + Send { + async { Ok(false) } } /// Called when the server confirmed our request to open a /// channel. A channel can only be written to after receiving this /// message (this library panics otherwise). #[allow(unused_variables)] - async fn channel_open_confirmation( - self, + fn channel_open_confirmation( + &mut self, id: ChannelId, max_packet_size: u32, window_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server signals success. #[allow(unused_variables)] - async fn channel_success( - self, + fn channel_success( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server signals failure. #[allow(unused_variables)] - async fn channel_failure( - self, + fn channel_failure( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server closes a channel. #[allow(unused_variables)] - async fn channel_close( - self, + fn channel_close( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server sends EOF to a channel. #[allow(unused_variables)] - async fn channel_eof( - self, + fn channel_eof( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server rejected our request to open a channel. #[allow(unused_variables)] - async fn channel_open_failure( - self, + fn channel_open_failure( + &mut self, channel: ChannelId, reason: ChannelOpenFailure, description: &str, language: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server opens a channel for a new remote port forwarding connection #[allow(unused_variables)] - async fn server_channel_open_forwarded_tcpip( - self, + fn server_channel_open_forwarded_tcpip( + &mut self, channel: Channel, connected_address: &str, connected_port: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } + } + + // Called when the server opens a channel for a new remote UDS forwarding connection + #[allow(unused_variables)] + fn server_channel_open_forwarded_streamlocal( + &mut self, + channel: Channel, + socket_path: &str, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server opens an agent forwarding channel #[allow(unused_variables)] - async fn server_channel_open_agent_forward( - self, - channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + fn server_channel_open_agent_forward( + &mut self, + channel: Channel, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } + } + + /// Called when the server attempts to open a channel of unknown type. It may return `true`, + /// if the channel of unknown type should be accepted. In this case, + /// [Handler::server_channel_open_unknown] will be called soon after. If it returns `false`, + /// the channel will not be created and a rejection message will be sent to the server. + #[allow(unused_variables)] + fn should_accept_unknown_server_channel( + &mut self, + id: ChannelId, + channel_type: &str, + ) -> impl Future + Send { + async { false } } - /// Called when the server gets an unknown channel. It may return `true`, - /// if the channel of unknown type should be handled. If it returns `false`, - /// the channel will not be created and an error will be sent to the server. + /// Called when the server opens an unknown channel. #[allow(unused_variables)] - fn server_channel_handle_unknown(&self, channel: ChannelId, channel_type: &[u8]) -> bool { - false + fn server_channel_open_unknown( + &mut self, + channel: Channel, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server opens a session channel. #[allow(unused_variables)] - async fn server_channel_open_session( - self, - channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + fn server_channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } - /// Called when the server opens a direct tcp/ip channel. + /// Called when the server opens a direct tcp/ip channel (non-standard). #[allow(unused_variables)] - async fn server_channel_open_direct_tcpip( - self, - channel: ChannelId, + fn server_channel_open_direct_tcpip( + &mut self, + channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } + } + + /// Called when the server opens a direct-streamlocal channel (non-standard). + #[allow(unused_variables)] + fn server_channel_open_direct_streamlocal( + &mut self, + channel: Channel, + socket_path: &str, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server opens an X11 channel. #[allow(unused_variables)] - async fn server_channel_open_x11( - self, + fn server_channel_open_x11( + &mut self, channel: Channel, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server sends us data. The `extended_code` @@ -1438,13 +1887,13 @@ pub trait Handler: Sized + Send { /// standard output, and `Some(1)` is the standard error. See /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2). #[allow(unused_variables)] - async fn data( - self, + fn data( + &mut self, channel: ChannelId, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the server sends us data. The `extended_code` @@ -1452,52 +1901,52 @@ pub trait Handler: Sized + Send { /// standard output, and `Some(1)` is the standard error. See /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2). #[allow(unused_variables)] - async fn extended_data( - self, + fn extended_data( + &mut self, channel: ChannelId, ext: u32, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The server informs this client of whether the client may /// perform control-S/control-Q flow control. See /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.8). #[allow(unused_variables)] - async fn xon_xoff( - self, + fn xon_xoff( + &mut self, channel: ChannelId, client_can_do: bool, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The remote process has exited, with the given exit status. #[allow(unused_variables)] - async fn exit_status( - self, + fn exit_status( + &mut self, channel: ChannelId, exit_status: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The remote process exited upon receiving a signal. #[allow(unused_variables)] - async fn exit_signal( - self, + fn exit_signal( + &mut self, channel: ChannelId, signal_name: Sig, core_dumped: bool, error_message: &str, lang_tag: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the network window is adjusted, meaning that we @@ -1506,13 +1955,13 @@ pub trait Handler: Sized + Send { /// `Session::data` before, and it returned less than the /// full amount of data. #[allow(unused_variables)] - async fn window_adjusted( - self, + fn window_adjusted( + &mut self, channel: ChannelId, new_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when this client adjusts the network window. Return the @@ -1524,12 +1973,31 @@ pub trait Handler: Sized + Send { /// Called when the server signals success. #[allow(unused_variables)] - async fn openssh_ext_host_keys_announced( - self, + fn openssh_ext_host_keys_announced( + &mut self, keys: Vec, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - debug!("openssh_ext_hostkeys_announced: {:?}", keys); - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async move { + debug!("openssh_ext_hostkeys_announced: {:?}", keys); + Ok(()) + } + } + + /// Called when the server sent a disconnect message + /// + /// If reason is an Error, this function should re-return the error so the join can also evaluate it + #[allow(unused_variables)] + fn disconnected( + &mut self, + reason: DisconnectReason, + ) -> impl Future> + Send { + async { + debug!("disconnected: {:?}", reason); + match reason { + DisconnectReason::ReceivedDisconnect(_) => Ok(()), + DisconnectReason::Error(e) => Err(e), + } + } } } diff --git a/russh/src/client/session.rs b/russh/src/client/session.rs index adc87da9..3ca7a09f 100644 --- a/russh/src/client/session.rs +++ b/russh/src/client/session.rs @@ -1,10 +1,10 @@ -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; use log::error; +use ssh_encoding::Encode; +use tokio::sync::oneshot; use crate::client::Session; use crate::session::EncryptedState; -use crate::{msg, ChannelId, Disconnect, Pty, Sig}; +use crate::{map_err, msg, ChannelId, CryptoVec, Disconnect, Pty, Sig}; impl Session { fn channel_open_generic( @@ -13,7 +13,7 @@ impl Session { write_suffix: F, ) -> Result where - F: FnOnce(&mut CryptoVec), + F: FnOnce(&mut CryptoVec) -> Result<(), crate::Error>, { let result = if let Some(ref mut enc) = self.common.encrypted { match enc.state { @@ -23,21 +23,27 @@ impl Session { self.common.config.maximum_packet_size, ); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_OPEN); - enc.write.extend_ssh_string(kind); + msg::CHANNEL_OPEN.encode(&mut enc.write)?; + kind.encode(&mut enc.write)?; // sender channel id. - enc.write.push_u32_be(sender_channel.0); + sender_channel.encode(&mut enc.write)?; // window. - enc.write - .push_u32_be(self.common.config.as_ref().window_size); + self.common + .config + .as_ref() + .window_size + .encode(&mut enc.write)?; // max packet size. - enc.write - .push_u32_be(self.common.config.as_ref().maximum_packet_size); + self.common + .config + .as_ref() + .maximum_packet_size + .encode(&mut enc.write)?; - write_suffix(&mut enc.write); + write_suffix(&mut enc.write)?; }); sender_channel } @@ -50,7 +56,7 @@ impl Session { } pub fn channel_open_session(&mut self) -> Result { - self.channel_open_generic(b"session", |_| ()) + self.channel_open_generic(b"session", |_| Ok(())) } pub fn channel_open_x11( @@ -59,8 +65,9 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"x11", |write| { - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + map_err!(originator_address.encode(write))?; + map_err!(originator_port.encode(write))?; // sender channel id. + Ok(()) }) } @@ -72,21 +79,23 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"direct-tcpip", |write| { - write.extend_ssh_string(host_to_connect.as_bytes()); - write.push_u32_be(port_to_connect); // sender channel id. - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + host_to_connect.encode(write)?; + port_to_connect.encode(write)?; // sender channel id. + originator_address.encode(write)?; + originator_port.encode(write)?; // sender channel id. + Ok(()) }) } pub fn channel_open_direct_streamlocal( &mut self, - socket_path: &str + socket_path: &str, ) -> Result { self.channel_open_generic(b"direct-streamlocal@openssh.com", |write| { - write.extend_ssh_string(socket_path.as_bytes()); - write.extend_ssh_string("".as_bytes()); // reserved - write.push_u32_be(0); // reserved + socket_path.encode(write)?; + "".encode(write)?; // reserved + 0u32.encode(write)?; // reserved + Ok(()) }) } @@ -101,32 +110,33 @@ impl Session { pix_width: u32, pix_height: u32, terminal_modes: &[(Pty, u32)], - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + map_err!(msg::CHANNEL_REQUEST.encode(&mut enc.write))?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"pty-req"); - enc.write.push(want_reply as u8); + channel.recipient_channel.encode(&mut enc.write)?; + "pty-req".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; - enc.write.extend_ssh_string(term.as_bytes()); - enc.write.push_u32_be(col_width); - enc.write.push_u32_be(row_height); - enc.write.push_u32_be(pix_width); - enc.write.push_u32_be(pix_height); + term.encode(&mut enc.write)?; + col_width.encode(&mut enc.write)?; + row_height.encode(&mut enc.write)?; + pix_width.encode(&mut enc.write)?; + pix_height.encode(&mut enc.write)?; - enc.write.push_u32_be((1 + 5 * terminal_modes.len()) as u32); + ((1 + 5 * terminal_modes.len()) as u32).encode(&mut enc.write)?; for &(code, value) in terminal_modes { - enc.write.push(code as u8); - enc.write.push_u32_be(value) + (code as u8).encode(&mut enc.write)?; + value.encode(&mut enc.write)?; } // 0 code (to terminate the list) - enc.write.push(0); + 0u8.encode(&mut enc.write)?; }); } } + Ok(()) } pub fn request_x11( @@ -137,24 +147,23 @@ impl Session { x11_authentication_protocol: &str, x11_authentication_cookie: &str, x11_screen_number: u32, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"x11-req"); + channel.recipient_channel.encode(&mut enc.write)?; + "x11-req".encode(&mut enc.write)?; enc.write.push(want_reply as u8); enc.write.push(single_connection as u8); - enc.write - .extend_ssh_string(x11_authentication_protocol.as_bytes()); - enc.write - .extend_ssh_string(x11_authentication_cookie.as_bytes()); - enc.write.push_u32_be(x11_screen_number); + x11_authentication_protocol.encode(&mut enc.write)?; + x11_authentication_cookie.encode(&mut enc.write)?; + x11_screen_number.encode(&mut enc.write)?; }); } } + Ok(()) } pub fn set_env( @@ -163,80 +172,99 @@ impl Session { want_reply: bool, variable_name: &str, variable_value: &str, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"env"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(variable_name.as_bytes()); - enc.write.extend_ssh_string(variable_value.as_bytes()); + channel.recipient_channel.encode(&mut enc.write)?; + "env".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + variable_name.encode(&mut enc.write)?; + variable_value.encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn request_shell(&mut self, want_reply: bool, channel: ChannelId) { + pub fn request_shell( + &mut self, + want_reply: bool, + channel: ChannelId, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"shell"); - enc.write.push(want_reply as u8); + channel.recipient_channel.encode(&mut enc.write)?; + "shell".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn exec(&mut self, channel: ChannelId, want_reply: bool, command: &[u8]) { + pub fn exec( + &mut self, + channel: ChannelId, + want_reply: bool, + command: &[u8], + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"exec"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(command); + channel.recipient_channel.encode(&mut enc.write)?; + "exec".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + command.encode(&mut enc.write)?; }); - return; + return Ok(()); } } error!("exec"); + Ok(()) } - pub fn signal(&mut self, channel: ChannelId, signal: Sig) { + pub fn signal(&mut self, channel: ChannelId, signal: Sig) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"signal"); - enc.write.push(0); - enc.write.extend_ssh_string(signal.name().as_bytes()); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + channel.recipient_channel.encode(&mut enc.write)?; + "signal".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + signal.name().encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn request_subsystem(&mut self, want_reply: bool, channel: ChannelId, name: &str) { + pub fn request_subsystem( + &mut self, + want_reply: bool, + channel: ChannelId, + name: &str, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"subsystem"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(name.as_bytes()); + channel.recipient_channel.encode(&mut enc.write)?; + "subsystem".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + name.encode(&mut enc.write)?; }); } } + Ok(()) } pub fn window_change( @@ -246,57 +274,168 @@ impl Session { row_height: u32, pix_width: u32, pix_height: u32, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"window-change"); - enc.write.push(0); // this packet never wants reply - enc.write.push_u32_be(col_width); - enc.write.push_u32_be(row_height); - enc.write.push_u32_be(pix_width); - enc.write.push_u32_be(pix_height); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + + channel.recipient_channel.encode(&mut enc.write)?; + "window-change".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + col_width.encode(&mut enc.write)?; + row_height.encode(&mut enc.write)?; + pix_width.encode(&mut enc.write)?; + pix_height.encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn tcpip_forward(&mut self, want_reply: bool, address: &str, port: u32) { + /// Requests a TCP/IP forwarding from the server + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// [`Some`] for a success message with port, or [`None`] for failure + pub fn tcpip_forward( + &mut self, + reply_channel: Option>>, + address: &str, + port: u32, + ) -> Result<(), crate::Error> { + if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::TcpIpForward(reply_channel), + ); + } + push_packet!(enc.write, { + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; + }); + } + Ok(()) + } + + /// Requests cancellation of TCP/IP forwarding from the server + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// `true` for a success message, or `false` for failure + pub fn cancel_tcpip_forward( + &mut self, + reply_channel: Option>, + address: &str, + port: u32, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::CancelTcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"tcpip-forward"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "cancel-tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } - pub fn cancel_tcpip_forward(&mut self, want_reply: bool, address: &str, port: u32) { + /// Requests a UDS forwarding from the server, `socket path` being the server side socket path. + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// `true` for a success message, or `false` for failure + pub fn streamlocal_forward( + &mut self, + reply_channel: Option>, + socket_path: &str, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::StreamLocalForward(reply_channel), + ); + } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"cancel-tcpip-forward"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "streamlocal-forward@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + socket_path.encode(&mut enc.write)?; }); } + Ok(()) } - pub fn data(&mut self, channel: ChannelId, data: CryptoVec) { + /// Requests cancellation of UDS forwarding from the server + /// + /// If `reply_channel` is not None, sets want_reply and returns the server's response via the channel, + /// `true` for a success message and `false` for failure. + pub fn cancel_streamlocal_forward( + &mut self, + reply_channel: Option>, + socket_path: &str, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { - enc.data(channel, data) + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::CancelStreamLocalForward(reply_channel), + ); + } + push_packet!(enc.write, { + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "cancel-streamlocal-forward@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + socket_path.encode(&mut enc.write)?; + }); + } + Ok(()) + } + + pub fn send_keepalive(&mut self, want_reply: bool) -> Result<(), crate::Error> { + self.open_global_requests + .push_back(crate::session::GlobalRequestResponse::Keepalive); + if let Some(ref mut enc) = self.common.encrypted { + push_packet!(enc.write, { + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "keepalive@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + }); + } + Ok(()) + } + + pub fn no_more_sessions(&mut self, want_reply: bool) -> Result<(), crate::Error> { + self.open_global_requests + .push_back(crate::session::GlobalRequestResponse::NoMoreSessions); + if let Some(ref mut enc) = self.common.encrypted { + push_packet!(enc.write, { + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "no-more-sessions@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + }); + } + Ok(()) + } + + pub fn data(&mut self, channel: ChannelId, data: CryptoVec) -> Result<(), crate::Error> { + if let Some(ref mut enc) = self.common.encrypted { + enc.data(channel, data, self.kex.active()) } else { unreachable!() } } - pub fn eof(&mut self, channel: ChannelId) { + pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { enc.eof(channel) } else { @@ -304,7 +443,7 @@ impl Session { } } - pub fn close(&mut self, channel: ChannelId) { + pub fn close(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { enc.close(channel) } else { @@ -312,29 +451,44 @@ impl Session { } } - pub fn extended_data(&mut self, channel: ChannelId, ext: u32, data: CryptoVec) { + pub fn extended_data( + &mut self, + channel: ChannelId, + ext: u32, + data: CryptoVec, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { - enc.extended_data(channel, ext, data) + enc.extended_data(channel, ext, data, self.kex.active()) } else { unreachable!() } } - pub fn agent_forward(&mut self, channel: ChannelId, want_reply: bool) { + pub fn agent_forward( + &mut self, + channel: ChannelId, + want_reply: bool, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"auth-agent-req@openssh.com"); - enc.write.push(want_reply as u8); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + channel.recipient_channel.encode(&mut enc.write)?; + "auth-agent-req@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) { - self.common.disconnect(reason, description, language_tag); + pub fn disconnect( + &mut self, + reason: Disconnect, + description: &str, + language_tag: &str, + ) -> Result<(), crate::Error> { + self.common.disconnect(reason, description, language_tag) } pub fn has_pending_data(&self, channel: ChannelId) -> bool { @@ -352,4 +506,17 @@ impl Session { 0 } } + + /// Returns the SSH ID (Protocol Version + Software Version) the server sent when connecting + /// + /// This should contain only ASCII characters for implementations conforming to RFC4253, Section 4.2: + /// + /// > Both the 'protoversion' and 'softwareversion' strings MUST consist of + /// > printable US-ASCII characters, with the exception of whitespace + /// > characters and the minus sign (-). + /// + /// So it usually is fine to convert it to a `String` using `String::from_utf8_lossy` + pub fn remote_sshid(&self) -> &[u8] { + &self.common.remote_sshid + } } diff --git a/russh/src/client/test.rs b/russh/src/client/test.rs new file mode 100644 index 00000000..0591170e --- /dev/null +++ b/russh/src/client/test.rs @@ -0,0 +1,161 @@ +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + use log::debug; + use rand_core::OsRng; + use ssh_key::PrivateKey; + use tokio::net::TcpListener; + + // Import client types directly since we're in the client module + use crate::client::{connect, Config, Handler}; + use crate::keys::PrivateKeyWithHashAlg; + use crate::server::{self, Auth, Handler as ServerHandler, Server, Session}; + use crate::{ChannelId, SshId}; // Import directly from crate root + use crate::{CryptoVec, Error}; + + #[derive(Clone)] + struct TestServer { + clients: Arc>>, + id: usize, + } + + impl server::Server for TestServer { + type Handler = Self; + + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } + } + + impl ServerHandler for TestServer { + type Error = Error; + + async fn channel_open_session( + &mut self, + channel: crate::channels::Channel, + session: &mut Session, + ) -> Result { + { + let mut clients = self.clients.lock().unwrap(); + clients.insert((self.id, channel.id()), session.handle()); + } + Ok(true) + } + + async fn auth_publickey( + &mut self, + _: &str, + _: &ssh_key::PublicKey, + ) -> Result { + debug!("auth_publickey"); + Ok(Auth::Accept) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + debug!("server received data: {:?}", std::str::from_utf8(data)); + session.data(channel, CryptoVec::from_slice(data))?; + Ok(()) + } + } + + struct Client {} + + impl Handler for Client { + type Error = Error; + + async fn check_server_key(&mut self, _: &ssh_key::PublicKey) -> Result { + Ok(true) + } + } + + #[tokio::test] + async fn test_client_connects_to_protocol_1_99() { + let _ = env_logger::try_init(); + + // Create a client key + let client_key = PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(); + + // Configure the server + let mut config = server::Config::default(); + config.auth_rejection_time = std::time::Duration::from_secs(1); + config.server_id = SshId::Standard("SSH-1.99-CustomServer_1.0".to_string()); + config.inactivity_timeout = None; + config + .keys + .push(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + let config = Arc::new(config); + + // Create server struct + let mut server = TestServer { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + }; + + // Start the TCP listener for our mock server + let socket = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + + // Spawn a separate task that will handle the server connection + tokio::spawn(async move { + // Accept a connection + let (socket, _) = socket.accept().await.unwrap(); + + // Handle the connection with the server + let server_handler = server.new_client(None); + server::run_stream(config, socket, server_handler) + .await + .unwrap(); + }); + + println!("Server listening on {}", addr); + + // Configure the client + let client_config = Arc::new(Config::default()); + + // Connect to the server + let mut session = connect(client_config, addr, Client {}).await.unwrap(); + + // Unfortunately, we can't directly verify the protocol version from the client API + // The Protocol199Stream wrapper ensures the server sends SSH-1.99-CustomServer_1.0 + // The test passing means the client accepted this protocol version + + // Try to authenticate + let auth_result = session + .authenticate_publickey( + std::env::var("USER").unwrap_or("user".to_string()), + PrivateKeyWithHashAlg::new( + Arc::new(client_key), + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) + .await + .unwrap(); + + assert!(auth_result.success()); + + // Try opening a session channel + let mut channel = session.channel_open_session().await.unwrap(); + + // Send some data + let test_data = b"Hello, 1.99 protocol server!"; + channel.data(&test_data[..]).await.unwrap(); + + // Wait for response + let msg = channel.wait().await.unwrap(); + match msg { + crate::channels::ChannelMsg::Data { data: msg_data } => { + assert_eq!(test_data.as_slice(), &msg_data[..]); + } + msg => panic!("Unexpected message {:?}", msg), + } + } +} diff --git a/russh/src/compression.rs b/russh/src/compression.rs index 20aff2cc..d6eec087 100644 --- a/russh/src/compression.rs +++ b/russh/src/compression.rs @@ -1,4 +1,9 @@ -#[derive(Debug)] +use std::convert::TryFrom; + +use delegate::delegate; +use ssh_encoding::Encode; + +#[derive(Debug, Clone)] pub enum Compression { None, #[cfg(feature = "flate2")] @@ -19,10 +24,50 @@ pub enum Decompress { Zlib(flate2::Decompress), } +#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)] +pub struct Name(&'static str); +impl AsRef for Name { + fn as_ref(&self) -> &str { + self.0 + } +} + +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + ALL_COMPRESSION_ALGORITHMS + .iter() + .find(|x| x.0 == s) + .map(|x| **x) + .ok_or(()) + } +} + +pub const NONE: Name = Name("none"); +#[cfg(feature = "flate2")] +pub const ZLIB: Name = Name("zlib"); +#[cfg(feature = "flate2")] +pub const ZLIB_LEGACY: Name = Name("zlib@openssh.com"); + +pub const ALL_COMPRESSION_ALGORITHMS: &[&Name] = &[ + &NONE, + #[cfg(feature = "flate2")] + &ZLIB, + #[cfg(feature = "flate2")] + &ZLIB_LEGACY, +]; + #[cfg(feature = "flate2")] impl Compression { - pub fn from_string(s: &str) -> Self { - if s == "zlib" || s == "zlib@openssh.com" { + pub fn new(name: &Name) -> Self { + if name == &ZLIB || name == &ZLIB_LEGACY { Compression::Zlib } else { Compression::None @@ -56,7 +101,7 @@ impl Compression { #[cfg(not(feature = "flate2"))] impl Compression { - pub fn from_string(_: &str) -> Self { + pub fn new(_name: &Name) -> Self { Compression::None } @@ -71,7 +116,7 @@ impl Compress { &mut self, input: &'a [u8], _: &'a mut russh_cryptovec::CryptoVec, - ) -> Result<&'a [u8], Error> { + ) -> Result<&'a [u8], crate::Error> { Ok(input) } } @@ -82,7 +127,7 @@ impl Decompress { &mut self, input: &'a [u8], _: &'a mut russh_cryptovec::CryptoVec, - ) -> Result<&'a [u8], Error> { + ) -> Result<&'a [u8], crate::Error> { Ok(input) } } diff --git a/russh/src/helpers.rs b/russh/src/helpers.rs new file mode 100644 index 00000000..b54c5baa --- /dev/null +++ b/russh/src/helpers.rs @@ -0,0 +1,124 @@ +use std::fmt::Debug; + +use ssh_encoding::{Decode, Encode}; +use ssh_key::private::KeypairData; +use ssh_key::Algorithm; + +#[doc(hidden)] +pub trait EncodedExt { + fn encoded(&self) -> ssh_key::Result>; +} + +impl EncodedExt for E { + fn encoded(&self) -> ssh_key::Result> { + let mut buf = Vec::new(); + self.encode(&mut buf)?; + Ok(buf) + } +} + +pub struct NameList(pub Vec); + +impl Debug for NameList { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl NameList { + pub fn as_encoded_string(&self) -> String { + self.0.join(",") + } + + pub fn from_encoded_string(value: &str) -> Self { + Self(value.split(',').map(|x| x.to_string()).collect()) + } +} + +impl Encode for NameList { + fn encoded_len(&self) -> Result { + self.as_encoded_string().encoded_len() + } + + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error> { + self.as_encoded_string().encode(writer) + } +} + +impl Decode for NameList { + fn decode(reader: &mut impl ssh_encoding::Reader) -> Result { + let s = String::decode(reader)?; + Ok(Self::from_encoded_string(&s)) + } + + type Error = ssh_encoding::Error; +} + +#[macro_export] +#[doc(hidden)] +#[allow(clippy::crate_in_macro_def)] +macro_rules! map_err { + ($result:expr) => { + $result.map_err(|e| crate::Error::from(e)) + }; +} + +pub use map_err; + +#[doc(hidden)] +pub fn sign_with_hash_alg(key: &PrivateKeyWithHashAlg, data: &[u8]) -> ssh_key::Result> { + Ok(match key.key_data() { + KeypairData::Rsa(rsa_keypair) => { + let Algorithm::Rsa { hash } = key.algorithm() else { + unreachable!(); + }; + signature::Signer::try_sign(&(rsa_keypair, hash), data)?.encoded()? + } + keypair => signature::Signer::try_sign(keypair, data)?.encoded()?, + }) +} + +mod algorithm { + use ssh_key::{Algorithm, HashAlg}; + + pub trait AlgorithmExt { + fn hash_alg(&self) -> Option; + fn with_hash_alg(&self, hash_alg: Option) -> Self; + fn new_certificate_ext(algo: &str) -> Result + where + Self: Sized; + } + + impl AlgorithmExt for Algorithm { + fn hash_alg(&self) -> Option { + match self { + Algorithm::Rsa { hash } => *hash, + _ => None, + } + } + + fn with_hash_alg(&self, hash_alg: Option) -> Self { + match self { + Algorithm::Rsa { .. } => Algorithm::Rsa { hash: hash_alg }, + x => x.clone(), + } + } + + fn new_certificate_ext(algo: &str) -> Result { + match algo { + "rsa-sha2-256-cert-v01@openssh.com" => Ok(Algorithm::Rsa { + hash: Some(HashAlg::Sha256), + }), + "rsa-sha2-512-cert-v01@openssh.com" => Ok(Algorithm::Rsa { + hash: Some(HashAlg::Sha512), + }), + x => Algorithm::new_certificate(x), + } + } + } +} + +#[doc(hidden)] +pub use algorithm::AlgorithmExt; + +use crate::keys::key::PrivateKeyWithHashAlg; diff --git a/russh/src/kex/curve25519.rs b/russh/src/kex/curve25519.rs index 772b668e..3ad78963 100644 --- a/russh/src/kex/curve25519.rs +++ b/russh/src/kex/curve25519.rs @@ -3,22 +3,23 @@ use curve25519_dalek::constants::ED25519_BASEPOINT_TABLE; use curve25519_dalek::montgomery::MontgomeryPoint; use curve25519_dalek::scalar::Scalar; use log::debug; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; +use ssh_encoding::{Encode, Writer}; -use super::{compute_keys, KexAlgorithm, KexType}; +use super::{compute_keys, KexAlgorithm, KexAlgorithmImplementor, KexType}; +use crate::kex::encode_mpint; use crate::mac::{self}; use crate::session::Exchange; -use crate::{cipher, msg}; +use crate::{cipher, msg, CryptoVec}; pub struct Curve25519KexType {} impl KexType for Curve25519KexType { - fn make(&self) -> Box { - Box::new(Curve25519Kex { + fn make(&self) -> KexAlgorithm { + Curve25519Kex { local_secret: None, shared_secret: None, - }) as Box + } + .into() } } @@ -40,7 +41,7 @@ impl std::fmt::Debug for Curve25519Kex { // We used to support curve "NIST P-256" here, but the security of // that curve is controversial, see // http://safecurves.cr.yp.to/rigid.html -impl KexAlgorithm for Curve25519Kex { +impl KexAlgorithmImplementor for Curve25519Kex { fn skip_exchange(&self) -> bool { false } @@ -86,7 +87,7 @@ impl KexAlgorithm for Curve25519Kex { fn client_dh( &mut self, client_ephemeral: &mut CryptoVec, - buf: &mut CryptoVec, + writer: &mut impl Writer, ) -> Result<(), crate::Error> { let client_secret = Scalar::from_bytes_mod_order(rand::random::<[u8; 32]>()); let client_pubkey = (ED25519_BASEPOINT_TABLE * &client_secret).to_montgomery(); @@ -95,17 +96,15 @@ impl KexAlgorithm for Curve25519Kex { client_ephemeral.clear(); client_ephemeral.extend(&client_pubkey.0); - buf.push(msg::KEX_ECDH_INIT); - buf.extend_ssh_string(&client_pubkey.0); + msg::KEX_ECDH_INIT.encode(writer)?; + client_pubkey.0.encode(writer)?; self.local_secret = Some(client_secret); Ok(()) } fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), crate::Error> { - let local_secret = - std::mem::replace(&mut self.local_secret, None).ok_or(crate::Error::KexInit)?; - + let local_secret = self.local_secret.take().ok_or(crate::Error::KexInit)?; let mut remote_pubkey = MontgomeryPoint([0; 32]); remote_pubkey.0.clone_from_slice(remote_pubkey_); let shared = local_secret * remote_pubkey; @@ -121,17 +120,17 @@ impl KexAlgorithm for Curve25519Kex { ) -> Result { // Computing the exchange hash, see page 7 of RFC 5656. buffer.clear(); - buffer.extend_ssh_string(&exchange.client_id); - buffer.extend_ssh_string(&exchange.server_id); - buffer.extend_ssh_string(&exchange.client_kex_init); - buffer.extend_ssh_string(&exchange.server_kex_init); + exchange.client_id.encode(buffer)?; + exchange.server_id.encode(buffer)?; + exchange.client_kex_init.encode(buffer)?; + exchange.server_kex_init.encode(buffer)?; buffer.extend(key); - buffer.extend_ssh_string(&exchange.client_ephemeral); - buffer.extend_ssh_string(&exchange.server_ephemeral); + exchange.client_ephemeral.encode(buffer)?; + exchange.server_ephemeral.encode(buffer)?; if let Some(ref shared) = self.shared_secret { - buffer.extend_ssh_mpint(&shared.0); + encode_mpint(&shared.0, buffer)?; } use sha2::Digest; diff --git a/russh/src/kex/dh/groups.rs b/russh/src/kex/dh/groups.rs index 56248c9c..58259c5f 100644 --- a/russh/src/kex/dh/groups.rs +++ b/russh/src/kex/dh/groups.rs @@ -1,16 +1,71 @@ +use std::fmt::Debug; +use std::ops::Deref; + use hex_literal::hex; use num_bigint::{BigUint, RandBigInt}; use rand; +#[derive(Clone)] +pub enum DhGroupUInt { + Static(&'static [u8]), + Owned(Vec), +} + +impl From> for DhGroupUInt { + fn from(x: Vec) -> Self { + Self::Owned(x) + } +} + +impl DhGroupUInt { + pub const fn new(x: &'static [u8]) -> Self { + Self::Static(x) + } +} + +impl Deref for DhGroupUInt { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + match self { + Self::Static(x) => x, + Self::Owned(x) => x, + } + } +} + +#[derive(Clone)] pub struct DhGroup { - pub(crate) prime: &'static [u8], - pub(crate) generator: usize, - pub(crate) exp_size: u64, + pub(crate) prime: DhGroupUInt, + pub(crate) generator: DhGroupUInt, + // pub(crate) exp_size: u64, +} + +impl DhGroup { + pub fn bit_size(&self) -> usize { + let Some(fsb_idx) = self.prime.deref().iter().position(|&x| x != 0) else { + return 0; + }; + (self.prime.deref().len() - fsb_idx) * 8 + } +} + +impl Debug for DhGroup { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DhGroup") + .field("prime", &format!("<{} bytes>", self.prime.deref().len())) + .field( + "generator", + &format!("<{} bytes>", self.generator.deref().len()), + ) + .finish() + } } pub const DH_GROUP1: DhGroup = DhGroup { - prime: hex!( - " + prime: DhGroupUInt::new( + hex!( + " FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 @@ -18,15 +73,17 @@ pub const DH_GROUP1: DhGroup = DhGroup { EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE65381 FFFFFFFF FFFFFFFF " - ) - .as_slice(), - generator: 2, - exp_size: 256, + ) + .as_slice(), + ), + generator: DhGroupUInt::new(&[2]), + // exp_size: 256, }; pub const DH_GROUP14: DhGroup = DhGroup { - prime: hex!( - " + prime: DhGroupUInt::new( + hex!( + " FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 @@ -39,17 +96,174 @@ pub const DH_GROUP14: DhGroup = DhGroup { DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 15728E5A 8AACAA68 FFFFFFFF FFFFFFFF " - ) - .as_slice(), - generator: 2, - exp_size: 256, + ) + .as_slice(), + ), + generator: DhGroupUInt::new(&[2]), + // exp_size: 256, +}; + +/// https://www.ietf.org/rfc/rfc3526.txt +pub const DH_GROUP15: DhGroup = DhGroup { + prime: DhGroupUInt::new( + hex!( + " + FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 + 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD + EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 + E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED + EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D + C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F + 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D + 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B + E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 + DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 + 15728E5A 8AAAC42D AD33170D 04507A33 A85521AB DF1CBA64 + ECFB8504 58DBEF0A 8AEA7157 5D060C7D B3970F85 A6E1E4C7 + ABF5AE8C DB0933D7 1E8C94E0 4A25619D CEE3D226 1AD2EE6B + F12FFA06 D98A0864 D8760273 3EC86A64 521F2B18 177B200C + BBE11757 7A615D6C 770988C0 BAD946E2 08E24FA0 74E5AB31 + 43DB5BFC E0FD108E 4B82D120 A93AD2CA FFFFFFFF FFFFFFFF + " + ) + .as_slice(), + ), + generator: DhGroupUInt::new(&[2]), +}; + +pub const DH_GROUP16: DhGroup = DhGroup { + prime: DhGroupUInt::new( + hex!( + " + FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 + 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD + EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 + E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED + EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D + C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F + 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D + 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B + E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 + DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 + 15728E5A 8AAAC42D AD33170D 04507A33 A85521AB DF1CBA64 + ECFB8504 58DBEF0A 8AEA7157 5D060C7D B3970F85 A6E1E4C7 + ABF5AE8C DB0933D7 1E8C94E0 4A25619D CEE3D226 1AD2EE6B + F12FFA06 D98A0864 D8760273 3EC86A64 521F2B18 177B200C + BBE11757 7A615D6C 770988C0 BAD946E2 08E24FA0 74E5AB31 + 43DB5BFC E0FD108E 4B82D120 A9210801 1A723C12 A787E6D7 + 88719A10 BDBA5B26 99C32718 6AF4E23C 1A946834 B6150BDA + 2583E9CA 2AD44CE8 DBBBC2DB 04DE8EF9 2E8EFC14 1FBECAA6 + 287C5947 4E6BC05D 99B2964F A090C3A2 233BA186 515BE7ED + 1F612970 CEE2D7AF B81BDD76 2170481C D0069127 D5B05AA9 + 93B4EA98 8D8FDDC1 86FFB7DC 90A6C08F 4DF435C9 34063199 + FFFFFFFF FFFFFFFF + " + ) + .as_slice(), + ), + generator: DhGroupUInt::new(&[2]), + // exp_size: 512, +}; + +/// https://www.ietf.org/rfc/rfc3526.txt +pub const DH_GROUP17: DhGroup = DhGroup { + prime: DhGroupUInt::new( + hex!( + " + FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 29024E08 + 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD EF9519B3 CD3A431B + 302B0A6D F25F1437 4FE1356D 6D51C245 E485B576 625E7EC6 F44C42E9 + A637ED6B 0BFF5CB6 F406B7ED EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 + 49286651 ECE45B3D C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 + FD24CF5F 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D + 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B E39E772C + 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 DE2BCBF6 95581718 + 3995497C EA956AE5 15D22618 98FA0510 15728E5A 8AAAC42D AD33170D + 04507A33 A85521AB DF1CBA64 ECFB8504 58DBEF0A 8AEA7157 5D060C7D + B3970F85 A6E1E4C7 ABF5AE8C DB0933D7 1E8C94E0 4A25619D CEE3D226 + 1AD2EE6B F12FFA06 D98A0864 D8760273 3EC86A64 521F2B18 177B200C + BBE11757 7A615D6C 770988C0 BAD946E2 08E24FA0 74E5AB31 43DB5BFC + E0FD108E 4B82D120 A9210801 1A723C12 A787E6D7 88719A10 BDBA5B26 + 99C32718 6AF4E23C 1A946834 B6150BDA 2583E9CA 2AD44CE8 DBBBC2DB + 04DE8EF9 2E8EFC14 1FBECAA6 287C5947 4E6BC05D 99B2964F A090C3A2 + 233BA186 515BE7ED 1F612970 CEE2D7AF B81BDD76 2170481C D0069127 + D5B05AA9 93B4EA98 8D8FDDC1 86FFB7DC 90A6C08F 4DF435C9 34028492 + 36C3FAB4 D27C7026 C1D4DCB2 602646DE C9751E76 3DBA37BD F8FF9406 + AD9E530E E5DB382F 413001AE B06A53ED 9027D831 179727B0 865A8918 + DA3EDBEB CF9B14ED 44CE6CBA CED4BB1B DB7F1447 E6CC254B 33205151 + 2BD7AF42 6FB8F401 378CD2BF 5983CA01 C64B92EC F032EA15 D1721D03 + F482D7CE 6E74FEF6 D55E702F 46980C82 B5A84031 900B1C9E 59E7C97F + BEC7E8F3 23A97A7E 36CC88BE 0F1D45B7 FF585AC5 4BD407B2 2B4154AA + CC8F6D7E BF48E1D8 14CC5ED2 0F8037E0 A79715EE F29BE328 06A1D58B + B7C5DA76 F550AA3D 8A1FBFF0 EB19CCB1 A313D55C DA56C9EC 2EF29632 + 387FE8D7 6E3C0468 043E8F66 3F4860EE 12BF2D5B 0B7474D6 E694F91E + 6DCC4024 FFFFFFFF FFFFFFFF + " + ) + .as_slice(), + ), + generator: DhGroupUInt::new(&[2]), +}; + +/// https://www.ietf.org/rfc/rfc3526.txt +pub const DH_GROUP18: DhGroup = DhGroup { + prime: DhGroupUInt::new( + hex!( + " + FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 + 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD + EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 + E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED + EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D + C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F + 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D + 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B + E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 + DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 + 15728E5A 8AAAC42D AD33170D 04507A33 A85521AB DF1CBA64 + ECFB8504 58DBEF0A 8AEA7157 5D060C7D B3970F85 A6E1E4C7 + ABF5AE8C DB0933D7 1E8C94E0 4A25619D CEE3D226 1AD2EE6B + F12FFA06 D98A0864 D8760273 3EC86A64 521F2B18 177B200C + BBE11757 7A615D6C 770988C0 BAD946E2 08E24FA0 74E5AB31 + 43DB5BFC E0FD108E 4B82D120 A9210801 1A723C12 A787E6D7 + 88719A10 BDBA5B26 99C32718 6AF4E23C 1A946834 B6150BDA + 2583E9CA 2AD44CE8 DBBBC2DB 04DE8EF9 2E8EFC14 1FBECAA6 + 287C5947 4E6BC05D 99B2964F A090C3A2 233BA186 515BE7ED + 1F612970 CEE2D7AF B81BDD76 2170481C D0069127 D5B05AA9 + 93B4EA98 8D8FDDC1 86FFB7DC 90A6C08F 4DF435C9 34028492 + 36C3FAB4 D27C7026 C1D4DCB2 602646DE C9751E76 3DBA37BD + F8FF9406 AD9E530E E5DB382F 413001AE B06A53ED 9027D831 + 179727B0 865A8918 DA3EDBEB CF9B14ED 44CE6CBA CED4BB1B + DB7F1447 E6CC254B 33205151 2BD7AF42 6FB8F401 378CD2BF + 5983CA01 C64B92EC F032EA15 D1721D03 F482D7CE 6E74FEF6 + D55E702F 46980C82 B5A84031 900B1C9E 59E7C97F BEC7E8F3 + 23A97A7E 36CC88BE 0F1D45B7 FF585AC5 4BD407B2 2B4154AA + CC8F6D7E BF48E1D8 14CC5ED2 0F8037E0 A79715EE F29BE328 + 06A1D58B B7C5DA76 F550AA3D 8A1FBFF0 EB19CCB1 A313D55C + DA56C9EC 2EF29632 387FE8D7 6E3C0468 043E8F66 3F4860EE + 12BF2D5B 0B7474D6 E694F91E 6DBE1159 74A3926F 12FEE5E4 + 38777CB6 A932DF8C D8BEC4D0 73B931BA 3BC832B6 8D9DD300 + 741FA7BF 8AFC47ED 2576F693 6BA42466 3AAB639C 5AE4F568 + 3423B474 2BF1C978 238F16CB E39D652D E3FDB8BE FC848AD9 + 22222E04 A4037C07 13EB57A8 1A23F0C7 3473FC64 6CEA306B + 4BCBC886 2F8385DD FA9D4B7F A2C087E8 79683303 ED5BDD3A + 062B3CF5 B3A278A6 6D2A13F8 3F44F82D DF310EE0 74AB6A36 + 4597E899 A0255DC1 64F31CC5 0846851D F9AB4819 5DED7EA1 + B1D510BD 7EE74D73 FAF36BC3 1ECFA268 359046F4 EB879F92 + 4009438B 481C6CD7 889A002E D5EE382B C9190DA6 FC026E47 + 9558E447 5677E9AA 9E3050E2 765694DF C81F56E8 80B96E71 + 60C980DD 98EDD3DF FFFFFFFF FFFFFFFF + " + ) + .as_slice(), + ), + generator: DhGroupUInt::new(&[2]), }; #[derive(Debug, PartialEq, Eq, Clone)] -pub struct DH { +pub(crate) struct DH { prime_num: BigUint, - generator: usize, - exp_size: u64, + generator: BigUint, private_key: BigUint, public_key: BigUint, shared_secret: BigUint, @@ -58,9 +272,8 @@ pub struct DH { impl DH { pub fn new(group: &DhGroup) -> Self { Self { - prime_num: BigUint::from_bytes_be(group.prime), - generator: group.generator, - exp_size: group.exp_size, + prime_num: BigUint::from_bytes_be(&group.prime), + generator: BigUint::from_bytes_be(&group.generator), private_key: BigUint::default(), public_key: BigUint::default(), shared_secret: BigUint::default(), @@ -70,19 +283,13 @@ impl DH { pub fn generate_private_key(&mut self, is_server: bool) -> BigUint { let q = (&self.prime_num - &BigUint::from(1u8)) / &BigUint::from(2u8); let mut rng = rand::thread_rng(); - self.private_key = rng.gen_biguint_range( - &if is_server { - 1u8.into() - } else { - 2u8.into() - }, - &q, - ); + self.private_key = + rng.gen_biguint_range(&if is_server { 1u8.into() } else { 2u8.into() }, &q); self.private_key.clone() } pub fn generate_public_key(&mut self) -> BigUint { - self.public_key = BigUint::from(self.generator).modpow(&self.private_key, &self.prime_num); + self.public_key = self.generator.modpow(&self.private_key, &self.prime_num); self.public_key.clone() } @@ -109,3 +316,5 @@ impl DH { public_key > &one && public_key < &prime_minus_one } } + +pub(crate) const BUILTIN_SAFE_DH_GROUPS: &[&DhGroup] = &[&DH_GROUP14, &DH_GROUP16]; diff --git a/russh/src/kex/dh/mod.rs b/russh/src/kex/dh/mod.rs index 0a8b8ddf..0bbfdc25 100644 --- a/russh/src/kex/dh/mod.rs +++ b/russh/src/kex/dh/mod.rs @@ -1,56 +1,109 @@ -mod groups; +pub mod groups; use std::marker::PhantomData; use byteorder::{BigEndian, ByteOrder}; use digest::Digest; use groups::DH; -use log::debug; +use log::{error, trace}; use num_bigint::BigUint; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; use sha1::Sha1; -use sha2::Sha256; - -use self::groups::{DhGroup, DH_GROUP1, DH_GROUP14}; -use super::{compute_keys, KexAlgorithm, KexType}; +use sha2::{Sha256, Sha512}; +use ssh_encoding::{Decode, Encode, Reader, Writer}; + +use self::groups::{ + DhGroup, DH_GROUP1, DH_GROUP14, DH_GROUP15, DH_GROUP16, DH_GROUP17, DH_GROUP18, +}; +use super::{compute_keys, KexAlgorithm, KexAlgorithmImplementor, KexType}; +use crate::client::GexParams; use crate::session::Exchange; -use crate::{cipher, mac, msg}; +use crate::{cipher, mac, msg, CryptoVec, Error}; + +pub(crate) struct DhGroup15Sha512KexType {} + +impl KexType for DhGroup15Sha512KexType { + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP15)).into() + } +} + +pub(crate) struct DhGroup17Sha512KexType {} + +impl KexType for DhGroup17Sha512KexType { + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP17)).into() + } +} + +pub(crate) struct DhGroup18Sha512KexType {} + +impl KexType for DhGroup18Sha512KexType { + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP18)).into() + } +} + +pub(crate) struct DhGexSha1KexType {} -pub struct DhGroup1Sha1KexType {} +impl KexType for DhGexSha1KexType { + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(None).into() + } +} + +pub(crate) struct DhGexSha256KexType {} + +impl KexType for DhGexSha256KexType { + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(None).into() + } +} + +pub(crate) struct DhGroup1Sha1KexType {} impl KexType for DhGroup1Sha1KexType { - fn make(&self) -> Box { - Box::new(DhGroupKex::::new(&DH_GROUP1)) as Box + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP1)).into() } } -pub struct DhGroup14Sha1KexType {} + +pub(crate) struct DhGroup14Sha1KexType {} impl KexType for DhGroup14Sha1KexType { - fn make(&self) -> Box { - Box::new(DhGroupKex::::new(&DH_GROUP14)) as Box + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP14)).into() } } -pub struct DhGroup14Sha256KexType {} + +pub(crate) struct DhGroup14Sha256KexType {} impl KexType for DhGroup14Sha256KexType { - fn make(&self) -> Box { - Box::new(DhGroupKex::::new(&DH_GROUP14)) as Box + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP14)).into() + } +} + +pub(crate) struct DhGroup16Sha512KexType {} + +impl KexType for DhGroup16Sha512KexType { + fn make(&self) -> KexAlgorithm { + DhGroupKex::::new(Some(&DH_GROUP16)).into() } } #[doc(hidden)] -pub struct DhGroupKex { - dh: DH, +pub(crate) struct DhGroupKex { + dh: Option, shared_secret: Option>, + is_dh_gex: bool, _digest: PhantomData, } impl DhGroupKex { - pub fn new(group: &DhGroup) -> DhGroupKex { - let dh = DH::new(group); + pub(crate) fn new(group: Option<&DhGroup>) -> DhGroupKex { DhGroupKex { - dh, + dh: group.map(DH::new), shared_secret: None, + is_dh_gex: group.is_none(), _digest: PhantomData, } } @@ -65,7 +118,7 @@ impl std::fmt::Debug for DhGroupKex { } } -fn biguint_to_mpint(biguint: &BigUint) -> Vec { +pub(crate) fn biguint_to_mpint(biguint: &BigUint) -> Vec { let mut mpint = Vec::new(); let bytes = biguint.to_bytes_be(); if let Some(b) = bytes.first() { @@ -77,38 +130,65 @@ fn biguint_to_mpint(biguint: &BigUint) -> Vec { mpint } -impl KexAlgorithm for DhGroupKex { +impl KexAlgorithmImplementor for DhGroupKex { fn skip_exchange(&self) -> bool { false } + fn is_dh_gex(&self) -> bool { + self.is_dh_gex + } + + fn client_dh_gex_init( + &mut self, + gex: &GexParams, + writer: &mut impl Writer, + ) -> Result<(), Error> { + msg::KEX_DH_GEX_REQUEST.encode(writer)?; + (gex.min_group_size() as u32).encode(writer)?; + (gex.preferred_group_size() as u32).encode(writer)?; + (gex.max_group_size() as u32).encode(writer)?; + Ok(()) + } + + #[allow(dead_code)] + fn dh_gex_set_group(&mut self, group: DhGroup) -> Result<(), crate::Error> { + self.dh = Some(DH::new(&group)); + Ok(()) + } + #[doc(hidden)] - fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), crate::Error> { - debug!("server_dh"); + fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), Error> { + let Some(dh) = self.dh.as_mut() else { + error!("DH kex sequence error, dh is None in server_dh"); + return Err(Error::Inconsistent); + }; let client_pubkey = { - if payload.first() != Some(&msg::KEX_ECDH_INIT) { - return Err(crate::Error::Inconsistent); + if payload.first() != Some(&msg::KEX_ECDH_INIT) + && payload.first() != Some(&msg::KEX_DH_GEX_INIT) + { + return Err(Error::Inconsistent); } #[allow(clippy::indexing_slicing)] // length checked let pubkey_len = BigEndian::read_u32(&payload[1..]) as usize; if payload.len() < 5 + pubkey_len { - return Err(crate::Error::Inconsistent); + return Err(Error::Inconsistent); } &payload .get(5..(5 + pubkey_len)) - .ok_or(crate::Error::Inconsistent)? + .ok_or(Error::Inconsistent)? }; - debug!("client_pubkey: {:?}", client_pubkey); + trace!("client_pubkey: {:?}", client_pubkey); - self.dh.generate_private_key(true); - let server_pubkey = &self.dh.generate_public_key(); - if !self.dh.validate_public_key(server_pubkey) { - return Err(crate::Error::Inconsistent); + dh.generate_private_key(true); + let server_pubkey = &dh.generate_public_key(); + if !dh.validate_public_key(server_pubkey) { + return Err(Error::Inconsistent); } let encoded_server_pubkey = biguint_to_mpint(server_pubkey); @@ -118,13 +198,13 @@ impl KexAlgorithm for DhGroupKex { exchange.server_ephemeral.extend(&encoded_server_pubkey); let decoded_client_pubkey = DH::decode_public_key(client_pubkey); - if !self.dh.validate_public_key(&decoded_client_pubkey) { - return Err(crate::Error::Inconsistent); + if !dh.validate_public_key(&decoded_client_pubkey) { + return Err(Error::Inconsistent); } - let shared = self.dh.compute_shared_secret(decoded_client_pubkey); - if !self.dh.validate_shared_secret(&shared) { - return Err(crate::Error::Inconsistent); + let shared = dh.compute_shared_secret(decoded_client_pubkey); + if !dh.validate_shared_secret(&shared) { + return Err(Error::Inconsistent); } self.shared_secret = Some(biguint_to_mpint(&shared)); Ok(()) @@ -134,13 +214,18 @@ impl KexAlgorithm for DhGroupKex { fn client_dh( &mut self, client_ephemeral: &mut CryptoVec, - buf: &mut CryptoVec, - ) -> Result<(), crate::Error> { - self.dh.generate_private_key(false); - let client_pubkey = &self.dh.generate_public_key(); + writer: &mut impl Writer, + ) -> Result<(), Error> { + let Some(dh) = self.dh.as_mut() else { + error!("DH kex sequence error, dh is None in client_dh"); + return Err(Error::Inconsistent); + }; + + dh.generate_private_key(false); + let client_pubkey = &dh.generate_public_key(); - if !self.dh.validate_public_key(client_pubkey) { - return Err(crate::Error::Inconsistent); + if !dh.validate_public_key(client_pubkey) { + return Err(Error::Inconsistent); } // fill exchange. @@ -148,22 +233,32 @@ impl KexAlgorithm for DhGroupKex { client_ephemeral.clear(); client_ephemeral.extend(&encoded_pubkey); - buf.push(msg::KEX_ECDH_INIT); - buf.extend_ssh_string(&encoded_pubkey); + if self.is_dh_gex { + msg::KEX_DH_GEX_INIT.encode(writer)?; + } else { + msg::KEX_ECDH_INIT.encode(writer)?; + } + + encoded_pubkey.encode(writer)?; Ok(()) } - fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), crate::Error> { + fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), Error> { + let Some(dh) = self.dh.as_mut() else { + error!("DH kex sequence error, dh is None in compute_shared_secret"); + return Err(Error::Inconsistent); + }; + let remote_pubkey = DH::decode_public_key(remote_pubkey_); - if !self.dh.validate_public_key(&remote_pubkey) { - return Err(crate::Error::Inconsistent); + if !dh.validate_public_key(&remote_pubkey) { + return Err(Error::Inconsistent); } - let shared = self.dh.compute_shared_secret(remote_pubkey); - if !self.dh.validate_shared_secret(&shared) { - return Err(crate::Error::Inconsistent); + let shared = dh.compute_shared_secret(remote_pubkey); + if !dh.validate_shared_secret(&shared) { + return Err(Error::Inconsistent); } self.shared_secret = Some(biguint_to_mpint(&shared)); Ok(()) @@ -174,20 +269,27 @@ impl KexAlgorithm for DhGroupKex { key: &CryptoVec, exchange: &Exchange, buffer: &mut CryptoVec, - ) -> Result { + ) -> Result { // Computing the exchange hash, see page 7 of RFC 5656. buffer.clear(); - buffer.extend_ssh_string(&exchange.client_id); - buffer.extend_ssh_string(&exchange.server_id); - buffer.extend_ssh_string(&exchange.client_kex_init); - buffer.extend_ssh_string(&exchange.server_kex_init); + exchange.client_id.encode(buffer)?; + exchange.server_id.encode(buffer)?; + exchange.client_kex_init.encode(buffer)?; + exchange.server_kex_init.encode(buffer)?; buffer.extend(key); - buffer.extend_ssh_string(&exchange.client_ephemeral); - buffer.extend_ssh_string(&exchange.server_ephemeral); + + if let Some((gex_params, dh_group)) = &exchange.gex { + gex_params.encode(buffer)?; + biguint_to_mpint(&BigUint::from_bytes_be(&dh_group.prime)).encode(buffer)?; + biguint_to_mpint(&BigUint::from_bytes_be(&dh_group.generator)).encode(buffer)?; + } + + exchange.client_ephemeral.encode(buffer)?; + exchange.server_ephemeral.encode(buffer)?; if let Some(ref shared) = self.shared_secret { - buffer.extend_ssh_mpint(shared); + shared.encode(buffer)?; } let mut hasher = D::new(); @@ -206,7 +308,7 @@ impl KexAlgorithm for DhGroupKex { remote_to_local_mac: mac::Name, local_to_remote_mac: mac::Name, is_server: bool, - ) -> Result { + ) -> Result { compute_keys::( self.shared_secret.as_deref(), session_id, @@ -218,3 +320,27 @@ impl KexAlgorithm for DhGroupKex { ) } } + +impl Encode for GexParams { + fn encoded_len(&self) -> Result { + Ok(0u32.encoded_len()? * 3) + } + + fn encode(&self, writer: &mut impl Writer) -> Result<(), ssh_encoding::Error> { + (self.min_group_size() as u32).encode(writer)?; + (self.preferred_group_size() as u32).encode(writer)?; + (self.max_group_size() as u32).encode(writer)?; + Ok(()) + } +} + +impl Decode for GexParams { + fn decode(reader: &mut impl Reader) -> Result { + let min_group_size = u32::decode(reader)? as usize; + let preferred_group_size = u32::decode(reader)? as usize; + let max_group_size = u32::decode(reader)? as usize; + GexParams::new(min_group_size, preferred_group_size, max_group_size) + } + + type Error = Error; +} diff --git a/russh/src/kex/ecdh_nistp.rs b/russh/src/kex/ecdh_nistp.rs new file mode 100644 index 00000000..80d7aa1e --- /dev/null +++ b/russh/src/kex/ecdh_nistp.rs @@ -0,0 +1,239 @@ +use std::marker::PhantomData; +use std::ops::Deref; + +use byteorder::{BigEndian, ByteOrder}; +use elliptic_curve::ecdh::{EphemeralSecret, SharedSecret}; +use elliptic_curve::point::PointCompression; +use elliptic_curve::sec1::{FromEncodedPoint, ModulusSize, ToEncodedPoint}; +use elliptic_curve::{AffinePoint, Curve, CurveArithmetic, FieldBytesSize}; +use log::debug; +use p256::NistP256; +use p384::NistP384; +use p521::NistP521; +use sha2::{Digest, Sha256, Sha384, Sha512}; +use ssh_encoding::{Encode, Writer}; + +use super::{encode_mpint, KexAlgorithm}; +use crate::kex::{compute_keys, KexAlgorithmImplementor, KexType}; +use crate::mac::{self}; +use crate::session::Exchange; +use crate::{cipher, msg, CryptoVec}; + +pub struct EcdhNistP256KexType {} + +impl KexType for EcdhNistP256KexType { + fn make(&self) -> KexAlgorithm { + EcdhNistPKex:: { + local_secret: None, + shared_secret: None, + _digest: PhantomData, + } + .into() + } +} + +pub struct EcdhNistP384KexType {} + +impl KexType for EcdhNistP384KexType { + fn make(&self) -> KexAlgorithm { + EcdhNistPKex:: { + local_secret: None, + shared_secret: None, + _digest: PhantomData, + } + .into() + } +} + +pub struct EcdhNistP521KexType {} + +impl KexType for EcdhNistP521KexType { + fn make(&self) -> KexAlgorithm { + EcdhNistPKex:: { + local_secret: None, + shared_secret: None, + _digest: PhantomData, + } + .into() + } +} + +#[doc(hidden)] +pub struct EcdhNistPKex { + local_secret: Option>, + shared_secret: Option>, + _digest: PhantomData, +} + +impl std::fmt::Debug for EcdhNistPKex { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Algorithm {{ local_secret: [hidden], shared_secret: [hidden] }}", + ) + } +} + +impl KexAlgorithmImplementor for EcdhNistPKex +where + C: PointCompression, + FieldBytesSize: ModulusSize, + AffinePoint: FromEncodedPoint + ToEncodedPoint, +{ + fn skip_exchange(&self) -> bool { + false + } + + #[doc(hidden)] + fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), crate::Error> { + debug!("server_dh"); + + let client_pubkey = { + if payload.first() != Some(&msg::KEX_ECDH_INIT) { + return Err(crate::Error::Inconsistent); + } + + #[allow(clippy::indexing_slicing)] // length checked + let pubkey_len = BigEndian::read_u32(&payload[1..]) as usize; + + if payload.len() < 5 + pubkey_len { + return Err(crate::Error::Inconsistent); + } + + #[allow(clippy::indexing_slicing)] // length checked + elliptic_curve::PublicKey::::from_sec1_bytes(&payload[5..(5 + pubkey_len)]) + .map_err(|_| crate::Error::Inconsistent)? + }; + + let server_secret = + elliptic_curve::ecdh::EphemeralSecret::::random(&mut rand_core::OsRng); + let server_pubkey = server_secret.public_key(); + + // fill exchange. + exchange.server_ephemeral.clear(); + exchange + .server_ephemeral + .extend(&server_pubkey.to_sec1_bytes()); + let shared = server_secret.diffie_hellman(&client_pubkey); + self.shared_secret = Some(shared); + Ok(()) + } + + #[doc(hidden)] + fn client_dh( + &mut self, + client_ephemeral: &mut CryptoVec, + writer: &mut impl Writer, + ) -> Result<(), crate::Error> { + let client_secret = + elliptic_curve::ecdh::EphemeralSecret::::random(&mut rand_core::OsRng); + let client_pubkey = client_secret.public_key(); + + // fill exchange. + client_ephemeral.clear(); + client_ephemeral.extend(&client_pubkey.to_sec1_bytes()); + + msg::KEX_ECDH_INIT.encode(writer)?; + client_pubkey.to_sec1_bytes().encode(writer)?; + + self.local_secret = Some(client_secret); + Ok(()) + } + + fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), crate::Error> { + let local_secret = self.local_secret.take().ok_or(crate::Error::KexInit)?; + let pubkey = elliptic_curve::PublicKey::::from_sec1_bytes(remote_pubkey_) + .map_err(|_| crate::Error::KexInit)?; + self.shared_secret = Some(local_secret.diffie_hellman(&pubkey)); + Ok(()) + } + + fn compute_exchange_hash( + &self, + key: &CryptoVec, + exchange: &Exchange, + buffer: &mut CryptoVec, + ) -> Result { + // Computing the exchange hash, see page 7 of RFC 5656. + buffer.clear(); + exchange.client_id.deref().encode(buffer)?; + exchange.server_id.deref().encode(buffer)?; + exchange.client_kex_init.deref().encode(buffer)?; + exchange.server_kex_init.deref().encode(buffer)?; + + buffer.extend(key); + exchange.client_ephemeral.deref().encode(buffer)?; + exchange.server_ephemeral.deref().encode(buffer)?; + + if let Some(ref shared) = self.shared_secret { + encode_mpint(shared.raw_secret_bytes(), buffer)?; + } + + let mut hasher = D::new(); + hasher.update(&buffer); + + let mut res = CryptoVec::new(); + res.extend(hasher.finalize().as_slice()); + Ok(res) + } + + fn compute_keys( + &self, + session_id: &CryptoVec, + exchange_hash: &CryptoVec, + cipher: cipher::Name, + remote_to_local_mac: mac::Name, + local_to_remote_mac: mac::Name, + is_server: bool, + ) -> Result { + compute_keys::( + self.shared_secret + .as_ref() + .map(|x| x.raw_secret_bytes() as &[u8]), + session_id, + exchange_hash, + cipher, + remote_to_local_mac, + local_to_remote_mac, + is_server, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_shared_secret() { + let mut party1 = EcdhNistPKex:: { + local_secret: Some(EphemeralSecret::::random(&mut rand_core::OsRng)), + shared_secret: None, + _digest: PhantomData, + }; + let p1_pubkey = party1.local_secret.as_ref().unwrap().public_key(); + + let mut party2 = EcdhNistPKex:: { + local_secret: Some(EphemeralSecret::::random(&mut rand_core::OsRng)), + shared_secret: None, + _digest: PhantomData, + }; + let p2_pubkey = party2.local_secret.as_ref().unwrap().public_key(); + + party1 + .compute_shared_secret(&p2_pubkey.to_sec1_bytes()) + .unwrap(); + + party2 + .compute_shared_secret(&p1_pubkey.to_sec1_bytes()) + .unwrap(); + + let p1_shared_secret = party1.shared_secret.unwrap(); + let p2_shared_secret = party2.shared_secret.unwrap(); + + assert_eq!( + p1_shared_secret.raw_secret_bytes(), + p2_shared_secret.raw_secret_bytes() + ) + } +} diff --git a/russh/src/kex/mod.rs b/russh/src/kex/mod.rs index cc413d65..eb9dea5e 100644 --- a/russh/src/kex/mod.rs +++ b/russh/src/kex/mod.rs @@ -16,53 +16,175 @@ //! //! This module exports kex algorithm names for use with [Preferred]. mod curve25519; -mod dh; +pub mod dh; +mod ecdh_nistp; mod none; use std::cell::RefCell; use std::collections::HashMap; +use std::convert::TryFrom; use std::fmt::Debug; +use std::ops::DerefMut; use curve25519::Curve25519KexType; -use dh::{DhGroup14Sha1KexType, DhGroup14Sha256KexType, DhGroup1Sha1KexType}; +use delegate::delegate; +use dh::groups::DhGroup; +use dh::{ + DhGexSha1KexType, DhGexSha256KexType, DhGroup14Sha1KexType, DhGroup14Sha256KexType, + DhGroup15Sha512KexType, DhGroup16Sha512KexType, DhGroup17Sha512KexType, DhGroup18Sha512KexType, + DhGroup1Sha1KexType, +}; use digest::Digest; +use ecdh_nistp::{EcdhNistP256KexType, EcdhNistP384KexType, EcdhNistP521KexType}; +use enum_dispatch::enum_dispatch; use once_cell::sync::Lazy; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; +use p256::NistP256; +use p384::NistP384; +use p521::NistP521; +use sha1::Sha1; +use sha2::{Sha256, Sha384, Sha512}; +use ssh_encoding::{Encode, Writer}; +use ssh_key::PublicKey; -use crate::cipher; use crate::cipher::CIPHERS; +use crate::client::GexParams; use crate::mac::{self, MACS}; -use crate::session::Exchange; +use crate::negotiation::Names; +use crate::session::{Exchange, NewKeys}; +use crate::{cipher, CryptoVec, Error}; + +#[derive(Debug)] +pub(crate) enum SessionKexState { + Idle, + InProgress(K), + Taken, // some async activity still going on such as host key checks +} + +impl PartialEq for SessionKexState { + fn eq(&self, other: &Self) -> bool { + core::mem::discriminant(self) == core::mem::discriminant(other) + } +} + +impl SessionKexState { + pub fn active(&self) -> bool { + match self { + SessionKexState::Idle => false, + SessionKexState::InProgress(_) => true, + SessionKexState::Taken => true, + } + } + + pub fn take(&mut self) -> Self { + // TODO maybe make this take a guarded closure + std::mem::replace( + self, + match self { + SessionKexState::Idle => SessionKexState::Idle, + _ => SessionKexState::Taken, + }, + ) + } +} + +#[derive(Debug)] +pub(crate) enum KexCause { + Initial, + Rekey { strict: bool, session_id: CryptoVec }, +} + +impl KexCause { + pub fn is_strict_kex(&self, names: &Names) -> bool { + names.strict_kex || matches!(self, Self::Rekey { strict: true, .. }) + } + + pub fn is_rekey(&self) -> bool { + match self { + Self::Initial => false, + Self::Rekey { .. } => true, + } + } + + pub fn session_id(&self) -> Option<&CryptoVec> { + match self { + Self::Initial => None, + Self::Rekey { session_id, .. } => Some(session_id), + } + } +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub(crate) enum KexProgress { + NeedsReply { + kex: T, + reset_seqn: bool, + }, + Done { + server_host_key: Option, + newkeys: NewKeys, + }, +} + +#[enum_dispatch(KexAlgorithmImplementor)] +pub(crate) enum KexAlgorithm { + DhGroupKexSha1(dh::DhGroupKex), + DhGroupKexSha256(dh::DhGroupKex), + DhGroupKexSha512(dh::DhGroupKex), + Curve25519Kex(curve25519::Curve25519Kex), + EcdhNistP256Kex(ecdh_nistp::EcdhNistPKex), + EcdhNistP384Kex(ecdh_nistp::EcdhNistPKex), + EcdhNistP521Kex(ecdh_nistp::EcdhNistPKex), + None(none::NoneKexAlgorithm), +} pub(crate) trait KexType { - fn make(&self) -> Box; + fn make(&self) -> KexAlgorithm; } -impl Debug for dyn KexAlgorithm + Send { +impl Debug for KexAlgorithm { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "KexAlgorithm") } } -pub(crate) trait KexAlgorithm { +#[enum_dispatch] +pub(crate) trait KexAlgorithmImplementor { fn skip_exchange(&self) -> bool; + fn is_dh_gex(&self) -> bool { + false + } + + #[allow(unused_variables)] + fn client_dh_gex_init( + &mut self, + gex: &GexParams, + writer: &mut impl Writer, + ) -> Result<(), Error> { + Err(Error::KexInit) + } - fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), crate::Error>; + #[allow(unused_variables)] + fn dh_gex_set_group(&mut self, group: DhGroup) -> Result<(), Error> { + Err(Error::KexInit) + } + + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] + fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), Error>; fn client_dh( &mut self, client_ephemeral: &mut CryptoVec, - buf: &mut CryptoVec, - ) -> Result<(), crate::Error>; + writer: &mut impl Writer, + ) -> Result<(), Error>; - fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), crate::Error>; + fn compute_shared_secret(&mut self, remote_pubkey_: &[u8]) -> Result<(), Error>; fn compute_exchange_hash( &self, key: &CryptoVec, exchange: &Exchange, buffer: &mut CryptoVec, - ) -> Result; + ) -> Result; fn compute_keys( &self, @@ -72,7 +194,7 @@ pub(crate) trait KexAlgorithm { remote_to_local_mac: mac::Name, local_to_remote_mac: mac::Name, is_server: bool, - ) -> Result; + ) -> Result; } #[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)] @@ -83,38 +205,111 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + KEXES.keys().find(|x| x.0 == s).map(|x| **x).ok_or(()) + } +} + /// `curve25519-sha256` pub const CURVE25519: Name = Name("curve25519-sha256"); /// `curve25519-sha256@libssh.org` pub const CURVE25519_PRE_RFC_8731: Name = Name("curve25519-sha256@libssh.org"); +/// `diffie-hellman-group-exchange-sha1`. +pub const DH_GEX_SHA1: Name = Name("diffie-hellman-group-exchange-sha1"); +/// `diffie-hellman-group-exchange-sha256`. +pub const DH_GEX_SHA256: Name = Name("diffie-hellman-group-exchange-sha256"); /// `diffie-hellman-group1-sha1` pub const DH_G1_SHA1: Name = Name("diffie-hellman-group1-sha1"); /// `diffie-hellman-group14-sha1` pub const DH_G14_SHA1: Name = Name("diffie-hellman-group14-sha1"); /// `diffie-hellman-group14-sha256` pub const DH_G14_SHA256: Name = Name("diffie-hellman-group14-sha256"); +/// `diffie-hellman-group15-sha512` +pub const DH_G15_SHA512: Name = Name("diffie-hellman-group15-sha512"); +/// `diffie-hellman-group16-sha512` +pub const DH_G16_SHA512: Name = Name("diffie-hellman-group16-sha512"); +/// `diffie-hellman-group17-sha512` +pub const DH_G17_SHA512: Name = Name("diffie-hellman-group17-sha512"); +/// `diffie-hellman-group18-sha512` +pub const DH_G18_SHA512: Name = Name("diffie-hellman-group18-sha512"); +/// `ecdh-sha2-nistp256` +pub const ECDH_SHA2_NISTP256: Name = Name("ecdh-sha2-nistp256"); +/// `ecdh-sha2-nistp384` +pub const ECDH_SHA2_NISTP384: Name = Name("ecdh-sha2-nistp384"); +/// `ecdh-sha2-nistp521` +pub const ECDH_SHA2_NISTP521: Name = Name("ecdh-sha2-nistp521"); /// `none` pub const NONE: Name = Name("none"); /// `ext-info-c` pub const EXTENSION_SUPPORT_AS_CLIENT: Name = Name("ext-info-c"); /// `ext-info-s` pub const EXTENSION_SUPPORT_AS_SERVER: Name = Name("ext-info-s"); +/// `kex-strict-c-v00@openssh.com` +pub const EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT: Name = Name("kex-strict-c-v00@openssh.com"); +/// `kex-strict-s-v00@openssh.com` +pub const EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER: Name = Name("kex-strict-s-v00@openssh.com"); const _CURVE25519: Curve25519KexType = Curve25519KexType {}; +const _DH_GEX_SHA1: DhGexSha1KexType = DhGexSha1KexType {}; +const _DH_GEX_SHA256: DhGexSha256KexType = DhGexSha256KexType {}; const _DH_G1_SHA1: DhGroup1Sha1KexType = DhGroup1Sha1KexType {}; const _DH_G14_SHA1: DhGroup14Sha1KexType = DhGroup14Sha1KexType {}; const _DH_G14_SHA256: DhGroup14Sha256KexType = DhGroup14Sha256KexType {}; +const _DH_G15_SHA512: DhGroup15Sha512KexType = DhGroup15Sha512KexType {}; +const _DH_G16_SHA512: DhGroup16Sha512KexType = DhGroup16Sha512KexType {}; +const _DH_G17_SHA512: DhGroup17Sha512KexType = DhGroup17Sha512KexType {}; +const _DH_G18_SHA512: DhGroup18Sha512KexType = DhGroup18Sha512KexType {}; +const _ECDH_SHA2_NISTP256: EcdhNistP256KexType = EcdhNistP256KexType {}; +const _ECDH_SHA2_NISTP384: EcdhNistP384KexType = EcdhNistP384KexType {}; +const _ECDH_SHA2_NISTP521: EcdhNistP521KexType = EcdhNistP521KexType {}; const _NONE: none::NoneKexType = none::NoneKexType {}; +pub const ALL_KEX_ALGORITHMS: &[&Name] = &[ + &CURVE25519, + &CURVE25519_PRE_RFC_8731, + &DH_GEX_SHA1, + &DH_GEX_SHA256, + &DH_G1_SHA1, + &DH_G14_SHA1, + &DH_G14_SHA256, + &DH_G15_SHA512, + &DH_G16_SHA512, + &DH_G17_SHA512, + &DH_G18_SHA512, + &ECDH_SHA2_NISTP256, + &ECDH_SHA2_NISTP384, + &ECDH_SHA2_NISTP521, + &NONE, +]; + pub(crate) static KEXES: Lazy> = Lazy::new(|| { let mut h: HashMap<&'static Name, &(dyn KexType + Send + Sync)> = HashMap::new(); h.insert(&CURVE25519, &_CURVE25519); h.insert(&CURVE25519_PRE_RFC_8731, &_CURVE25519); + h.insert(&DH_GEX_SHA1, &_DH_GEX_SHA1); + h.insert(&DH_GEX_SHA256, &_DH_GEX_SHA256); + h.insert(&DH_G18_SHA512, &_DH_G18_SHA512); + h.insert(&DH_G17_SHA512, &_DH_G17_SHA512); + h.insert(&DH_G16_SHA512, &_DH_G16_SHA512); + h.insert(&DH_G15_SHA512, &_DH_G15_SHA512); h.insert(&DH_G14_SHA256, &_DH_G14_SHA256); h.insert(&DH_G14_SHA1, &_DH_G14_SHA1); h.insert(&DH_G1_SHA1, &_DH_G1_SHA1); + h.insert(&ECDH_SHA2_NISTP256, &_ECDH_SHA2_NISTP256); + h.insert(&ECDH_SHA2_NISTP384, &_ECDH_SHA2_NISTP384); + h.insert(&ECDH_SHA2_NISTP521, &_ECDH_SHA2_NISTP521); h.insert(&NONE, &_NONE); + assert_eq!(ALL_KEX_ALGORITHMS.len(), h.len()); h }); @@ -133,27 +328,23 @@ pub(crate) fn compute_keys( remote_to_local_mac: mac::Name, local_to_remote_mac: mac::Name, is_server: bool, -) -> Result { - let cipher = CIPHERS.get(&cipher).ok_or(crate::Error::UnknownAlgo)?; - let remote_to_local_mac = MACS - .get(&remote_to_local_mac) - .ok_or(crate::Error::UnknownAlgo)?; - let local_to_remote_mac = MACS - .get(&local_to_remote_mac) - .ok_or(crate::Error::UnknownAlgo)?; +) -> Result { + let cipher = CIPHERS.get(&cipher).ok_or(Error::UnknownAlgo)?; + let remote_to_local_mac = MACS.get(&remote_to_local_mac).ok_or(Error::UnknownAlgo)?; + let local_to_remote_mac = MACS.get(&local_to_remote_mac).ok_or(Error::UnknownAlgo)?; // https://tools.ietf.org/html/rfc4253#section-7.2 BUFFER.with(|buffer| { KEY_BUF.with(|key| { NONCE_BUF.with(|nonce| { MAC_BUF.with(|mac| { - let compute_key = |c, key: &mut CryptoVec, len| -> Result<(), crate::Error> { + let compute_key = |c, key: &mut CryptoVec, len| -> Result<(), Error> { let mut buffer = buffer.borrow_mut(); buffer.clear(); key.clear(); if let Some(shared) = shared_secret { - buffer.extend_ssh_mpint(shared); + encode_mpint(shared, buffer.deref_mut())?; } buffer.extend(exchange_hash.as_ref()); @@ -170,7 +361,7 @@ pub(crate) fn compute_keys( // extend. buffer.clear(); if let Some(shared) = shared_secret { - buffer.extend_ssh_mpint(shared); + encode_mpint(shared, buffer.deref_mut())?; } buffer.extend(exchange_hash.as_ref()); buffer.extend(key); @@ -238,3 +429,23 @@ pub(crate) fn compute_keys( }) }) } + +// NOTE: using MpInt::from_bytes().encode() will randomly fail, +// I'm assuming it's due to specific byte values / padding but no time to investigate +#[allow(clippy::indexing_slicing)] // length is known +pub(crate) fn encode_mpint(s: &[u8], w: &mut W) -> Result<(), Error> { + // Skip initial 0s. + let mut i = 0; + while i < s.len() && s[i] == 0 { + i += 1 + } + // If the first non-zero is >= 128, write its length (u32, BE), followed by 0. + if s[i] & 0x80 != 0 { + ((s.len() - i + 1) as u32).encode(w)?; + 0u8.encode(w)?; + } else { + ((s.len() - i) as u32).encode(w)?; + } + w.write(&s[i..])?; + Ok(()) +} diff --git a/russh/src/kex/none.rs b/russh/src/kex/none.rs index 66d903ac..91678abb 100644 --- a/russh/src/kex/none.rs +++ b/russh/src/kex/none.rs @@ -1,18 +1,20 @@ -use russh_cryptovec::CryptoVec; +use ssh_encoding::Writer; -use super::{KexAlgorithm, KexType}; +use super::{KexAlgorithm, KexAlgorithmImplementor, KexType}; +use crate::CryptoVec; pub struct NoneKexType {} impl KexType for NoneKexType { - fn make(&self) -> Box { - Box::new(NoneKexAlgorithm {}) as Box + fn make(&self) -> KexAlgorithm { + NoneKexAlgorithm {}.into() } } -struct NoneKexAlgorithm {} +#[doc(hidden)] +pub struct NoneKexAlgorithm {} -impl KexAlgorithm for NoneKexAlgorithm { +impl KexAlgorithmImplementor for NoneKexAlgorithm { fn skip_exchange(&self) -> bool { true } @@ -28,7 +30,7 @@ impl KexAlgorithm for NoneKexAlgorithm { fn client_dh( &mut self, _client_ephemeral: &mut russh_cryptovec::CryptoVec, - _buf: &mut russh_cryptovec::CryptoVec, + _buf: &mut impl Writer, ) -> Result<(), crate::Error> { Ok(()) } diff --git a/russh/src/key.rs b/russh/src/key.rs deleted file mode 100644 index 17a28f68..00000000 --- a/russh/src/key.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2016 Pierre-Étienne Meunier -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::*; -use russh_keys::key::*; - -#[doc(hidden)] -pub trait PubKey { - fn push_to(&self, buffer: &mut CryptoVec); -} - -impl PubKey for PublicKey { - fn push_to(&self, buffer: &mut CryptoVec) { - match self { - PublicKey::Ed25519(ref public) => { - buffer.push_u32_be((ED25519.0.len() + public.as_bytes().len() + 8) as u32); - buffer.extend_ssh_string(ED25519.0.as_bytes()); - buffer.extend_ssh_string(public.as_bytes()); - } - #[cfg(feature = "openssl")] - PublicKey::RSA { ref key, .. } => { - #[allow(clippy::unwrap_used)] // type known - let rsa = key.0.rsa().unwrap(); - let e = rsa.e().to_vec(); - let n = rsa.n().to_vec(); - buffer.push_u32_be((4 + SSH_RSA.0.len() + mpint_len(&n) + mpint_len(&e)) as u32); - buffer.extend_ssh_string(SSH_RSA.0.as_bytes()); - buffer.extend_ssh_mpint(&e); - buffer.extend_ssh_mpint(&n); - } - } - } -} - -impl PubKey for KeyPair { - fn push_to(&self, buffer: &mut CryptoVec) { - match self { - KeyPair::Ed25519(ref key) => { - let public = key.verifying_key().to_bytes(); - buffer.push_u32_be((ED25519.0.len() + public.len() + 8) as u32); - buffer.extend_ssh_string(ED25519.0.as_bytes()); - buffer.extend_ssh_string(public.as_slice()); - } - #[cfg(feature = "openssl")] - KeyPair::RSA { ref key, .. } => { - let e = key.e().to_vec(); - let n = key.n().to_vec(); - buffer.push_u32_be((4 + SSH_RSA.0.len() + mpint_len(&n) + mpint_len(&e)) as u32); - buffer.extend_ssh_string(SSH_RSA.0.as_bytes()); - buffer.extend_ssh_mpint(&e); - buffer.extend_ssh_mpint(&n); - } - } - } -} diff --git a/russh/src/keys/agent/client.rs b/russh/src/keys/agent/client.rs new file mode 100644 index 00000000..3026075a --- /dev/null +++ b/russh/src/keys/agent/client.rs @@ -0,0 +1,475 @@ +use core::str; + +use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; +use log::{debug, error}; +use ssh_encoding::{Decode, Encode, Reader}; +use ssh_key::{Algorithm, HashAlg, PrivateKey, PublicKey, Signature}; +use tokio; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use super::{msg, Constraint}; +use crate::helpers::EncodedExt; +use crate::keys::{key, Error}; +use crate::CryptoVec; + +pub trait AgentStream: AsyncRead + AsyncWrite {} + +impl AgentStream for S {} + +/// SSH agent client. +pub struct AgentClient { + stream: S, + buf: CryptoVec, +} + +impl AgentClient { + /// Wraps the internal stream in a Box, allowing different client + /// implementations to have the same type + pub fn dynamic(self) -> AgentClient> { + AgentClient { + stream: Box::new(self.stream), + buf: self.buf, + } + } + + pub fn into_inner(self) -> Box { + Box::new(self.stream) + } +} + +// https://tools.ietf.org/html/draft-miller-ssh-agent-00#section-4.1 +impl AgentClient { + /// Build a future that connects to an SSH agent via the provided + /// stream (on Unix, usually a Unix-domain socket). + pub fn connect(stream: S) -> Self { + AgentClient { + stream, + buf: CryptoVec::new(), + } + } +} + +#[cfg(unix)] +impl AgentClient { + /// Connect to an SSH agent via the provided + /// stream (on Unix, usually a Unix-domain socket). + pub async fn connect_uds>(path: P) -> Result { + let stream = tokio::net::UnixStream::connect(path).await?; + Ok(AgentClient { + stream, + buf: CryptoVec::new(), + }) + } + + /// Connect to an SSH agent specified by the SSH_AUTH_SOCK + /// environment variable. + pub async fn connect_env() -> Result { + let var = if let Ok(var) = std::env::var("SSH_AUTH_SOCK") { + var + } else { + return Err(Error::EnvVar("SSH_AUTH_SOCK")); + }; + match Self::connect_uds(var).await { + Err(Error::IO(io_err)) if io_err.kind() == std::io::ErrorKind::NotFound => { + Err(Error::BadAuthSock) + } + owise => owise, + } + } +} + +#[cfg(windows)] +const ERROR_PIPE_BUSY: u32 = 231u32; + +#[cfg(windows)] +impl AgentClient { + /// Connect to a running Pageant instance + pub async fn connect_pageant() -> Self { + Self::connect(pageant::PageantStream::new()) + } +} + +#[cfg(windows)] +impl AgentClient { + /// Connect to an SSH agent via a Windows named pipe + pub async fn connect_named_pipe>(path: P) -> Result { + let stream = loop { + match tokio::net::windows::named_pipe::ClientOptions::new().open(path.as_ref()) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + Err(e) => return Err(e.into()), + } + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + }; + + Ok(AgentClient { + stream, + buf: CryptoVec::new(), + }) + } +} + +impl AgentClient { + async fn read_response(&mut self) -> Result<(), Error> { + // Writing the message + self.stream.write_all(&self.buf).await?; + self.stream.flush().await?; + + // Reading the length + self.buf.clear(); + self.buf.resize(4); + self.stream.read_exact(&mut self.buf).await?; + + // Reading the rest of the buffer + let len = BigEndian::read_u32(&self.buf) as usize; + self.buf.clear(); + self.buf.resize(len); + self.stream.read_exact(&mut self.buf).await?; + + Ok(()) + } + + async fn read_success(&mut self) -> Result<(), Error> { + self.read_response().await?; + if self.buf.first() == Some(&msg::SUCCESS) { + Ok(()) + } else { + Err(Error::AgentFailure) + } + } + + /// Send a key to the agent, with a (possibly empty) slice of + /// constraints to apply when using the key to sign. + pub async fn add_identity( + &mut self, + key: &PrivateKey, + constraints: &[Constraint], + ) -> Result<(), Error> { + // See IETF draft-miller-ssh-agent-13, section 3.2 for format. + // https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent + self.buf.clear(); + self.buf.resize(4); + if constraints.is_empty() { + self.buf.push(msg::ADD_IDENTITY) + } else { + self.buf.push(msg::ADD_ID_CONSTRAINED) + } + + key.key_data().encode(&mut self.buf)?; + "".encode(&mut self.buf)?; // comment field + + if !constraints.is_empty() { + for cons in constraints { + match *cons { + Constraint::KeyLifetime { seconds } => { + msg::CONSTRAIN_LIFETIME.encode(&mut self.buf)?; + seconds.encode(&mut self.buf)?; + } + Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), + Constraint::Extensions { + ref name, + ref details, + } => { + msg::CONSTRAIN_EXTENSION.encode(&mut self.buf)?; + name.encode(&mut self.buf)?; + details.encode(&mut self.buf)?; + } + } + } + } + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + + self.read_success().await?; + Ok(()) + } + + /// Add a smart card to the agent, with a (possibly empty) set of + /// constraints to apply when signing. + pub async fn add_smartcard_key( + &mut self, + id: &str, + pin: &[u8], + constraints: &[Constraint], + ) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + if constraints.is_empty() { + self.buf.push(msg::ADD_SMARTCARD_KEY) + } else { + self.buf.push(msg::ADD_SMARTCARD_KEY_CONSTRAINED) + } + id.encode(&mut self.buf)?; + pin.encode(&mut self.buf)?; + if !constraints.is_empty() { + (constraints.len() as u32).encode(&mut self.buf)?; + for cons in constraints { + match *cons { + Constraint::KeyLifetime { seconds } => { + msg::CONSTRAIN_LIFETIME.encode(&mut self.buf)?; + seconds.encode(&mut self.buf)?; + } + Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), + Constraint::Extensions { + ref name, + ref details, + } => { + msg::CONSTRAIN_EXTENSION.encode(&mut self.buf)?; + name.encode(&mut self.buf)?; + details.encode(&mut self.buf)?; + } + } + } + } + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + self.read_response().await?; + Ok(()) + } + + /// Lock the agent, making it refuse to sign until unlocked. + pub async fn lock(&mut self, passphrase: &[u8]) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + self.buf.push(msg::LOCK); + passphrase.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + self.read_response().await?; + Ok(()) + } + + /// Unlock the agent, allowing it to sign again. + pub async fn unlock(&mut self, passphrase: &[u8]) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + msg::UNLOCK.encode(&mut self.buf)?; + passphrase.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + #[allow(clippy::indexing_slicing)] // static length + BigEndian::write_u32(&mut self.buf[..], len as u32); + self.read_response().await?; + Ok(()) + } + + /// Ask the agent for a list of the currently registered secret + /// keys. + pub async fn request_identities(&mut self) -> Result, Error> { + self.buf.clear(); + self.buf.resize(4); + msg::REQUEST_IDENTITIES.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + + self.read_response().await?; + debug!("identities: {:?}", &self.buf[..]); + let mut keys = Vec::new(); + + #[allow(clippy::indexing_slicing)] // static length + if let Some((&msg::IDENTITIES_ANSWER, mut r)) = self.buf.split_first() { + let n = u32::decode(&mut r)?; + for _ in 0..n { + let key_blob = Bytes::decode(&mut r)?; + let comment = String::decode(&mut r)?; + let mut key = key::parse_public_key(&key_blob)?; + key.set_comment(comment); + keys.push(key); + } + } + + Ok(keys) + } + + /// Ask the agent to sign the supplied piece of data. + pub async fn sign_request( + &mut self, + public: &PublicKey, + hash_alg: Option, + mut data: CryptoVec, + ) -> Result { + debug!("sign_request: {:?}", data); + let hash = self.prepare_sign_request(public, hash_alg, &data)?; + + self.read_response().await?; + + match self.buf.split_first() { + Some((&msg::SIGN_RESPONSE, mut r)) => { + self.write_signature(&mut r, hash, &mut data)?; + Ok(data) + } + Some((&msg::FAILURE, _)) => Err(Error::AgentFailure), + _ => { + debug!("self.buf = {:?}", &self.buf[..]); + Err(Error::AgentProtocolError) + } + } + } + + fn prepare_sign_request( + &mut self, + public: &ssh_key::PublicKey, + hash_alg: Option, + data: &[u8], + ) -> Result { + self.buf.clear(); + self.buf.resize(4); + msg::SIGN_REQUEST.encode(&mut self.buf)?; + public.key_data().encoded()?.encode(&mut self.buf)?; + data.encode(&mut self.buf)?; + debug!("public = {:?}", public); + + let hash = match public.algorithm() { + Algorithm::Rsa { .. } => match hash_alg { + Some(HashAlg::Sha256) => 2, + Some(HashAlg::Sha512) => 4, + _ => 0, + }, + _ => 0, + }; + + hash.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + Ok(hash) + } + + fn write_signature( + &self, + r: &mut R, + hash: u32, + data: &mut CryptoVec, + ) -> Result<(), Error> { + let mut resp = &Bytes::decode(r)?[..]; + let t = String::decode(&mut resp)?; + if (hash == 2 && t == "rsa-sha2-256") || (hash == 4 && t == "rsa-sha2-512") || hash == 0 { + let sig = Bytes::decode(&mut resp)?; + (t.len() + sig.len() + 8).encode(data)?; + t.encode(data)?; + sig.encode(data)?; + Ok(()) + } else { + error!("unexpected agent signature type: {:?}", t); + Err(Error::AgentProtocolError) + } + } + + /// Ask the agent to sign the supplied piece of data. + pub fn sign_request_base64( + mut self, + public: &ssh_key::PublicKey, + hash_alg: Option, + data: &[u8], + ) -> impl futures::Future)> { + debug!("sign_request: {:?}", data); + let r = self.prepare_sign_request(public, hash_alg, data); + async move { + if let Err(e) = r { + return (self, Err(e)); + } + + let resp = self.read_response().await; + if let Err(e) = resp { + return (self, Err(e)); + } + + #[allow(clippy::indexing_slicing)] // length is checked + if !self.buf.is_empty() && self.buf[0] == msg::SIGN_RESPONSE { + let base64 = data_encoding::BASE64_NOPAD.encode(&self.buf[1..]); + (self, Ok(base64)) + } else { + (self, Ok(String::new())) + } + } + } + + /// Ask the agent to sign the supplied piece of data, and return a `Signature`. + pub async fn sign_request_signature( + &mut self, + public: &ssh_key::PublicKey, + hash_alg: Option, + data: &[u8], + ) -> Result { + debug!("sign_request: {:?}", data); + + self.prepare_sign_request(public, hash_alg, data)?; + self.read_response().await?; + + match self.buf.split_first() { + Some((&msg::SIGN_RESPONSE, mut r)) => { + let mut resp = &Bytes::decode(&mut r)?[..]; + let sig = Signature::decode(&mut resp)?; + Ok(sig) + } + _ => Err(Error::AgentProtocolError), + } + } + + /// Ask the agent to remove a key from its memory. + pub async fn remove_identity(&mut self, public: &ssh_key::PublicKey) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + self.buf.push(msg::REMOVE_IDENTITY); + public.key_data().encoded()?.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + self.read_response().await?; + Ok(()) + } + + /// Ask the agent to remove a smartcard from its memory. + pub async fn remove_smartcard_key(&mut self, id: &str, pin: &[u8]) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + msg::REMOVE_SMARTCARD_KEY.encode(&mut self.buf)?; + id.encode(&mut self.buf)?; + pin.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + BigEndian::write_u32(&mut self.buf[..], len as u32); + self.read_response().await?; + Ok(()) + } + + /// Ask the agent to forget all known keys. + pub async fn remove_all_identities(&mut self) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + msg::REMOVE_ALL_IDENTITIES.encode(&mut self.buf)?; + 1u32.encode(&mut self.buf)?; + self.read_success().await?; + Ok(()) + } + + /// Send a custom message to the agent. + pub async fn extension(&mut self, typ: &[u8], ext: &[u8]) -> Result<(), Error> { + self.buf.clear(); + self.buf.resize(4); + msg::EXTENSION.encode(&mut self.buf)?; + typ.encode(&mut self.buf)?; + ext.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + (len as u32).encode(&mut self.buf)?; + self.read_response().await?; + Ok(()) + } + + /// Ask the agent what extensions about supported extensions. + pub async fn query_extension(&mut self, typ: &[u8], mut ext: CryptoVec) -> Result { + self.buf.clear(); + self.buf.resize(4); + msg::EXTENSION.encode(&mut self.buf)?; + typ.encode(&mut self.buf)?; + let len = self.buf.len() - 4; + (len as u32).encode(&mut self.buf)?; + self.read_response().await?; + + match self.buf.split_first() { + Some((&msg::SUCCESS, mut r)) => { + ext.extend(&Bytes::decode(&mut r)?); + Ok(true) + } + _ => Ok(false), + } + } +} diff --git a/russh-keys/src/agent/mod.rs b/russh/src/keys/agent/mod.rs similarity index 100% rename from russh-keys/src/agent/mod.rs rename to russh/src/keys/agent/mod.rs diff --git a/russh-keys/src/agent/msg.rs b/russh/src/keys/agent/msg.rs similarity index 89% rename from russh-keys/src/agent/msg.rs rename to russh/src/keys/agent/msg.rs index a77c5091..d732e674 100644 --- a/russh-keys/src/agent/msg.rs +++ b/russh/src/keys/agent/msg.rs @@ -19,4 +19,5 @@ pub const EXTENSION: u8 = 27; pub const CONSTRAIN_LIFETIME: u8 = 1; pub const CONSTRAIN_CONFIRM: u8 = 2; -pub const CONSTRAIN_EXTENSION: u8 = 3; +// pub const CONSTRAIN_MAXSIGN: u8 = 3; +pub const CONSTRAIN_EXTENSION: u8 = 255; diff --git a/russh-keys/src/agent/server.rs b/russh/src/keys/agent/server.rs similarity index 56% rename from russh-keys/src/agent/server.rs rename to russh/src/keys/agent/server.rs index c61a8e0d..a3833c51 100644 --- a/russh-keys/src/agent/server.rs +++ b/russh/src/keys/agent/server.rs @@ -1,27 +1,27 @@ use std::collections::HashMap; -use std::convert::TryFrom; use std::marker::Sync; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; -use async_trait::async_trait; use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; use futures::future::Future; use futures::stream::{Stream, StreamExt}; -use russh_cryptovec::CryptoVec; +use ssh_encoding::{Decode, Encode, Reader}; +use ssh_key::PrivateKey; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::sleep; use {std, tokio}; use super::{msg, Constraint}; -use crate::encoding::{Encoding, Position, Reader}; -#[cfg(feature = "openssl")] -use crate::key::SignatureHash; -use crate::{key, Error}; +use crate::helpers::{sign_with_hash_alg, EncodedExt}; +use crate::keys::key::PrivateKeyWithHashAlg; +use crate::keys::Error; +use crate::CryptoVec; #[derive(Clone)] #[allow(clippy::type_complexity)] -struct KeyStore(Arc, (Arc, SystemTime, Vec)>>>); +struct KeyStore(Arc, (Arc, SystemTime, Vec)>>>); #[derive(Clone)] struct Lock(Arc>); @@ -43,17 +43,17 @@ pub enum MessageType { Unlock, } -#[async_trait] +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] pub trait Agent: Clone + Send + 'static { fn confirm( self, - _pk: Arc, + _pk: Arc, ) -> Box + Unpin + Send> { Box::new(futures::future::ready((self, true))) } - async fn confirm_request(&self, _msg: MessageType) -> bool { - true + fn confirm_request(&self, _msg: MessageType) -> impl Future + Send { + async { true } } } @@ -68,7 +68,7 @@ where while let Some(Ok(stream)) = listener.next().await { let mut buf = CryptoVec::new(); buf.resize(4); - tokio::spawn( + russh_util::runtime::spawn( (Connection { lock: lock.clone(), keys: keys.clone(), @@ -83,10 +83,7 @@ where } impl Agent for () { - fn confirm( - self, - _: Arc, - ) -> Box + Unpin + Send> { + fn confirm(self, _: Arc) -> Box + Unpin + Send> { Box::new(futures::future::ready((self, true))) } } @@ -131,26 +128,30 @@ impl { + + match self.buf.split_first() { + Some((&11, _)) + if !is_locked && agentref.confirm_request(MessageType::RequestKeys).await => + { // request identities if let Ok(keys) = self.keys.0.read() { - writebuf.push(msg::IDENTITIES_ANSWER); - writebuf.push_u32_be(keys.len() as u32); + msg::IDENTITIES_ANSWER.encode(writebuf)?; + (keys.len() as u32).encode(writebuf)?; for (k, _) in keys.iter() { - writebuf.extend_ssh_string(k); - writebuf.extend_ssh_string(b""); + k.encode(writebuf)?; + "".encode(writebuf)?; } } else { - writebuf.push(msg::FAILURE) + msg::FAILURE.encode(writebuf)? } } - Ok(13) if !is_locked && agentref.confirm_request(MessageType::Sign).await => { + Some((&13, mut r)) + if !is_locked && agentref.confirm_request(MessageType::Sign).await => + { // sign request let agent = self.agent.take().ok_or(Error::AgentFailure)?; - let (agent, signed) = self.try_sign(agent, r, writebuf).await?; + let (agent, signed) = self.try_sign(agent, &mut r, writebuf).await?; self.agent = Some(agent); if signed { return Ok(()); @@ -159,22 +160,28 @@ impl { + Some((&17, mut r)) + if !is_locked && agentref.confirm_request(MessageType::AddKeys).await => + { // add identity - if let Ok(true) = self.add_key(r, false, writebuf).await { + if let Ok(true) = self.add_key(&mut r, false, writebuf).await { } else { writebuf.push(msg::FAILURE) } } - Ok(18) if !is_locked && agentref.confirm_request(MessageType::RemoveKeys).await => { + Some((&18, mut r)) + if !is_locked && agentref.confirm_request(MessageType::RemoveKeys).await => + { // remove identity - if let Ok(true) = self.remove_identity(r) { + if let Ok(true) = self.remove_identity(&mut r) { writebuf.push(msg::SUCCESS) } else { writebuf.push(msg::FAILURE) } } - Ok(19) if !is_locked && agentref.confirm_request(MessageType::RemoveAllKeys).await => { + Some((&19, _)) + if !is_locked && agentref.confirm_request(MessageType::RemoveAllKeys).await => + { // remove all identities if let Ok(mut keys) = self.keys.0.write() { keys.clear(); @@ -183,25 +190,31 @@ impl { + Some((&22, mut r)) + if !is_locked && agentref.confirm_request(MessageType::Lock).await => + { // lock - if let Ok(()) = self.lock(r) { + if let Ok(()) = self.lock(&mut r) { writebuf.push(msg::SUCCESS) } else { writebuf.push(msg::FAILURE) } } - Ok(23) if is_locked && agentref.confirm_request(MessageType::Unlock).await => { + Some((&23, mut r)) + if is_locked && agentref.confirm_request(MessageType::Unlock).await => + { // unlock - if let Ok(true) = self.unlock(r) { + if let Ok(true) = self.unlock(&mut r) { writebuf.push(msg::SUCCESS) } else { writebuf.push(msg::FAILURE) } } - Ok(25) if !is_locked && agentref.confirm_request(MessageType::AddKeys).await => { + Some((&25, mut r)) + if !is_locked && agentref.confirm_request(MessageType::AddKeys).await => + { // add identity constrained - if let Ok(true) = self.add_key(r, true, writebuf).await { + if let Ok(true) = self.add_key(&mut r, true, writebuf).await { } else { writebuf.push(msg::FAILURE) } @@ -216,17 +229,17 @@ impl Result<(), Error> { - let password = r.read_string()?; + fn lock(&self, r: &mut R) -> Result<(), Error> { + let password = Bytes::decode(r)?; let mut lock = self.lock.0.write().or(Err(Error::AgentFailure))?; - lock.extend(password); + lock.extend(&password); Ok(()) } - fn unlock(&self, mut r: Position) -> Result { - let password = r.read_string()?; + fn unlock(&self, r: &mut R) -> Result { + let password = Bytes::decode(r)?; let mut lock = self.lock.0.write().or(Err(Error::AgentFailure))?; - if &lock[..] == password { + if lock[..] == password { lock.clear(); Ok(true) } else { @@ -234,9 +247,9 @@ impl Result { + fn remove_identity(&self, r: &mut R) -> Result { if let Ok(mut keys) = self.keys.0.write() { - if keys.remove(r.read_string()?).is_some() { + if keys.remove(&Bytes::decode(r)?.to_vec()).is_some() { Ok(true) } else { Ok(false) @@ -246,94 +259,31 @@ impl( &self, - mut r: Position<'_>, + r: &mut R, constrained: bool, writebuf: &mut CryptoVec, ) -> Result { - let pos0 = r.position; - let t = r.read_string()?; - let (blob, key) = match t { - b"ssh-ed25519" => { - let pos1 = r.position; - let concat = r.read_string()?; - let _comment = r.read_string()?; - #[allow(clippy::indexing_slicing)] // length checked before - let secret = ed25519_dalek::SigningKey::try_from( - concat.get(..32).ok_or(Error::KeyIsCorrupt)?, - ).map_err(|_| Error::KeyIsCorrupt)?; - - writebuf.push(msg::SUCCESS); + let (blob, key_pair) = { + let private_key = + ssh_key::private::PrivateKey::new(ssh_key::private::KeypairData::decode(r)?, "")?; + let _comment = String::decode(r)?; - #[allow(clippy::indexing_slicing)] // positions checked before - (self.buf[pos0..pos1].to_vec(), key::KeyPair::Ed25519(secret)) - } - #[cfg(feature = "openssl")] - b"ssh-rsa" => { - use openssl::bn::{BigNum, BigNumContext}; - use openssl::rsa::Rsa; - let n = r.read_mpint()?; - let e = r.read_mpint()?; - let d = BigNum::from_slice(r.read_mpint()?)?; - let q_inv = r.read_mpint()?; - let p = BigNum::from_slice(r.read_mpint()?)?; - let q = BigNum::from_slice(r.read_mpint()?)?; - let (dp, dq) = { - let one = BigNum::from_u32(1)?; - let p1 = p.as_ref() - one.as_ref(); - let q1 = q.as_ref() - one.as_ref(); - let mut context = BigNumContext::new()?; - let mut dp = BigNum::new()?; - let mut dq = BigNum::new()?; - dp.checked_rem(&d, &p1, &mut context)?; - dq.checked_rem(&d, &q1, &mut context)?; - (dp, dq) - }; - let _comment = r.read_string()?; - let key = Rsa::from_private_components( - BigNum::from_slice(n)?, - BigNum::from_slice(e)?, - d, - p, - q, - dp, - dq, - BigNum::from_slice(q_inv)?, - )?; - - let len0 = writebuf.len(); - writebuf.extend_ssh_string(b"ssh-rsa"); - writebuf.extend_ssh_mpint(e); - writebuf.extend_ssh_mpint(n); - - #[allow(clippy::indexing_slicing)] // length is known - let blob = writebuf[len0..].to_vec(); - writebuf.resize(len0); - writebuf.push(msg::SUCCESS); - ( - blob, - key::KeyPair::RSA { - key, - hash: SignatureHash::SHA2_256, - }, - ) - } - _ => return Ok(false), + (private_key.public_key().key_data().encoded()?, private_key) }; + writebuf.push(msg::SUCCESS); let mut w = self.keys.0.write().or(Err(Error::AgentFailure))?; let now = SystemTime::now(); if constrained { - let n = r.read_u32()?; let mut c = Vec::new(); - for _ in 0..n { - let t = r.read_byte()?; + while let Ok(t) = u8::decode(r) { if t == msg::CONSTRAIN_LIFETIME { - let seconds = r.read_u32()?; + let seconds = u32::decode(r)?; c.push(Constraint::KeyLifetime { seconds }); let blob = blob.clone(); let keys = self.keys.clone(); - tokio::spawn(async move { + russh_util::runtime::spawn(async move { sleep(Duration::from_secs(seconds as u64)).await; if let Ok(mut keys) = keys.0.write() { let delete = if let Some(&(_, time, _)) = keys.get(&blob) { @@ -352,24 +302,24 @@ impl( &self, agent: A, - mut r: Position<'_>, + r: &mut R, writebuf: &mut CryptoVec, ) -> Result<(A, bool), Error> { let mut needs_confirm = false; let key = { - let blob = r.read_string()?; + let blob = Bytes::decode(r)?; let k = self.keys.0.read().or(Err(Error::AgentFailure))?; - if let Some(&(ref key, _, ref constraints)) = k.get(blob) { + if let Some((key, _, constraints)) = k.get(&blob.to_vec()) { if constraints.iter().any(|c| *c == Constraint::Confirm) { needs_confirm = true; } @@ -379,7 +329,11 @@ impl) -> Result { +pub fn decode_secret_key(secret: &str, password: Option<&str>) -> Result { + if secret.trim().starts_with("PuTTY-User-Key-File-") { + return Ok(PrivateKey::from_ppk(secret, password.map(Into::into))?); + } let mut format = None; let secret = { let mut started = false; @@ -57,39 +62,26 @@ pub fn decode_secret_key(secret: &str, password: Option<&str>) -> Result = HEXLOWER_PERMISSIVE - .decode(l.split_at(AES_128_CBC.len()).1.as_bytes())?; - if iv_.len() != 16 { - return Err(Error::CouldNotReadKey); - } - let mut iv = [0; 16]; - iv.clone_from_slice(&iv_); - format = Some(Format::Pkcs5Encrypted(Encryption::Aes128Cbc(iv))) + let iv_: Vec = + HEXLOWER_PERMISSIVE.decode(l.split_at(AES_128_CBC.len()).1.as_bytes())?; + if iv_.len() != 16 { + return Err(Error::CouldNotReadKey); } + let mut iv = [0; 16]; + iv.clone_from_slice(&iv_); + format = Some(Format::Pkcs5Encrypted(Encryption::Aes128Cbc(iv))) } } if l == "-----BEGIN OPENSSH PRIVATE KEY-----" { started = true; format = Some(Format::Openssh); } else if l == "-----BEGIN RSA PRIVATE KEY-----" { - #[cfg(not(feature = "openssl"))] - { - return Err(Error::UnsupportedKeyType { - key_type_string: "rsa".to_owned(), - key_type_raw: "rsa".as_bytes().to_vec(), - }); - } - #[cfg(feature = "openssl")] - { - started = true; - format = Some(Format::Rsa); - } + started = true; + format = Some(Format::Rsa); } else if l == "-----BEGIN ENCRYPTED PRIVATE KEY-----" { started = true; format = Some(Format::Pkcs8Encrypted); - } else if l == "-----BEGIN PRIVATE KEY-----" { + } else if l == "-----BEGIN PRIVATE KEY-----" || l == "-----BEGIN EC PRIVATE KEY-----" { started = true; format = Some(Format::Pkcs8); } @@ -100,19 +92,28 @@ pub fn decode_secret_key(secret: &str, password: Option<&str>) -> Result decode_openssh(&secret, password), - #[cfg(feature = "openssl")] - Some(Format::Rsa) => decode_rsa(&secret), - #[cfg(feature = "openssl")] + Some(Format::Rsa) => Ok(decode_rsa_pkcs1_der(&secret)?.into()), Some(Format::Pkcs5Encrypted(enc)) => decode_pkcs5(&secret, password, enc), Some(Format::Pkcs8Encrypted) | Some(Format::Pkcs8) => { - self::pkcs8::decode_pkcs8(&secret, password.map(|x| x.as_bytes())) + let result = self::pkcs8::decode_pkcs8(&secret, password.map(|x| x.as_bytes())); + #[cfg(feature = "legacy-ed25519-pkcs8-parser")] + { + if result.is_err() { + let legacy_result = + pkcs8_legacy::decode_pkcs8(&secret, password.map(|x| x.as_bytes())); + if let Ok(key) = legacy_result { + return Ok(key); + } + } + } + result } None => Err(Error::CouldNotReadKey), } } -pub fn encode_pkcs8_pem(key: &key::KeyPair, mut w: W) -> Result<(), Error> { - let x = self::pkcs8::encode_pkcs8(key); +pub fn encode_pkcs8_pem(key: &PrivateKey, mut w: W) -> Result<(), Error> { + let x = self::pkcs8::encode_pkcs8(key)?; w.write_all(b"-----BEGIN PRIVATE KEY-----\n")?; w.write_all(BASE64_MIME.encode(&x).as_bytes())?; w.write_all(b"\n-----END PRIVATE KEY-----\n")?; @@ -120,7 +121,7 @@ pub fn encode_pkcs8_pem(key: &key::KeyPair, mut w: W) -> Result<(), Er } pub fn encode_pkcs8_pem_encrypted( - key: &key::KeyPair, + key: &PrivateKey, pass: &[u8], rounds: u32, mut w: W, @@ -132,10 +133,6 @@ pub fn encode_pkcs8_pem_encrypted( Ok(()) } -#[cfg(feature = "openssl")] -fn decode_rsa(secret: &[u8]) -> Result { - Ok(key::KeyPair::RSA { - key: Rsa::private_key_from_der(secret)?, - hash: key::SignatureHash::SHA2_256, - }) +fn decode_rsa_pkcs1_der(secret: &[u8]) -> Result { + Ok(rsa::RsaPrivateKey::from_pkcs1_der(secret)?.try_into()?) } diff --git a/russh/src/keys/format/openssh.rs b/russh/src/keys/format/openssh.rs new file mode 100644 index 00000000..cdcbb98a --- /dev/null +++ b/russh/src/keys/format/openssh.rs @@ -0,0 +1,17 @@ +use ssh_key::PrivateKey; + +use crate::keys::Error; + +/// Decode a secret key given in the OpenSSH format, deciphering it if +/// needed using the supplied password. +pub fn decode_openssh(secret: &[u8], password: Option<&str>) -> Result { + let pk = PrivateKey::from_bytes(secret)?; + if pk.is_encrypted() { + if let Some(password) = password { + return Ok(pk.decrypt(password)?); + } else { + return Err(Error::KeyIsEncrypted); + } + } + Ok(pk) +} diff --git a/russh-keys/src/format/pkcs5.rs b/russh/src/keys/format/pkcs5.rs similarity index 76% rename from russh-keys/src/format/pkcs5.rs rename to russh/src/keys/format/pkcs5.rs index 0e5a2a5e..47d4fc18 100644 --- a/russh-keys/src/format/pkcs5.rs +++ b/russh/src/keys/format/pkcs5.rs @@ -1,16 +1,16 @@ use aes::*; +use ssh_key::PrivateKey; use super::Encryption; -use crate::{key, Error}; +use crate::keys::Error; -/// Decode a secret key in the PKCS#5 format, possible deciphering it +/// Decode a secret key in the PKCS#5 format, possibly deciphering it /// using the supplied password. -#[cfg(feature = "openssl")] pub fn decode_pkcs5( secret: &[u8], password: Option<&str>, enc: Encryption, -) -> Result { +) -> Result { use aes::cipher::{BlockDecryptMut, KeyIvInit}; use block_padding::Pkcs7; @@ -25,12 +25,11 @@ pub fn decode_pkcs5( #[allow(clippy::unwrap_used)] // AES parameters are static let c = cbc::Decryptor::::new_from_slices(&md5.0, &iv[..]).unwrap(); let mut dec = secret.to_vec(); - c.decrypt_padded_mut::(&mut dec)?; - dec + c.decrypt_padded_mut::(&mut dec)?.to_vec() } Encryption::Aes256Cbc(_) => unimplemented!(), }; - super::decode_rsa(&sec) + super::decode_rsa_pkcs1_der(&sec).map(Into::into) } else { Err(Error::KeyIsEncrypted) } diff --git a/russh/src/keys/format/pkcs8.rs b/russh/src/keys/format/pkcs8.rs new file mode 100644 index 00000000..5753af9a --- /dev/null +++ b/russh/src/keys/format/pkcs8.rs @@ -0,0 +1,170 @@ +use std::convert::{TryFrom, TryInto}; + +use p256::NistP256; +use p384::NistP384; +use p521::NistP521; +use pkcs8::{AssociatedOid, EncodePrivateKey, PrivateKeyInfo, SecretDocument}; +use spki::ObjectIdentifier; +use ssh_key::private::{EcdsaKeypair, Ed25519Keypair, Ed25519PrivateKey, KeypairData}; +use ssh_key::PrivateKey; + +use crate::keys::Error; + +/// Decode a PKCS#8-encoded private key (ASN.1 or X9.62) +pub fn decode_pkcs8( + ciphertext: &[u8], + password: Option<&[u8]>, +) -> Result { + let doc = SecretDocument::try_from(ciphertext)?; + let doc = if let Some(password) = password { + doc.decode_msg::()? + .decrypt(password)? + } else { + doc + }; + + match doc.decode_msg::() { + Ok(key) => { + // X9.62 EC private key + let Some(curve) = key.parameters.and_then(|x| x.named_curve()) else { + return Err(Error::CouldNotReadKey); + }; + let kp = ec_key_data_into_keypair(curve, key)?; + Ok(PrivateKey::new(KeypairData::Ecdsa(kp), "")?) + } + Err(_) => { + // ASN.1 key + Ok(pkcs8_pki_into_keypair_data(doc.decode_msg::()?)?.try_into()?) + } + } +} + +fn pkcs8_pki_into_keypair_data(pki: PrivateKeyInfo<'_>) -> Result { + match pki.algorithm.oid { + ed25519_dalek::pkcs8::ALGORITHM_OID => { + let kpb = ed25519_dalek::pkcs8::KeypairBytes::try_from(pki)?; + let pk = Ed25519PrivateKey::from_bytes(&kpb.secret_key); + Ok(KeypairData::Ed25519(Ed25519Keypair { + public: pk.clone().into(), + private: pk, + })) + } + pkcs1::ALGORITHM_OID => { + let sk = &pkcs1::RsaPrivateKey::try_from(pki.private_key)?; + let pk = rsa::RsaPrivateKey::from_components( + rsa::BigUint::from_bytes_be(sk.modulus.as_bytes()), + rsa::BigUint::from_bytes_be(sk.public_exponent.as_bytes()), + rsa::BigUint::from_bytes_be(sk.private_exponent.as_bytes()), + vec![ + rsa::BigUint::from_bytes_be(sk.prime1.as_bytes()), + rsa::BigUint::from_bytes_be(sk.prime2.as_bytes()), + ], + )?; + Ok(KeypairData::Rsa(pk.try_into()?)) + } + sec1::ALGORITHM_OID => Ok(KeypairData::Ecdsa(ec_key_data_into_keypair( + pki.algorithm.parameters_oid()?, + pki, + )?)), + oid => Err(Error::UnknownAlgorithm(oid)), + } +} + +fn ec_key_data_into_keypair( + curve_oid: ObjectIdentifier, + private_key: K, +) -> Result +where + p256::SecretKey: TryFrom, + p384::SecretKey: TryFrom, + p521::SecretKey: TryFrom, + crate::keys::Error: From, +{ + Ok(match curve_oid { + NistP256::OID => { + let sk = p256::SecretKey::try_from(private_key)?; + EcdsaKeypair::NistP256 { + public: sk.public_key().into(), + private: sk.into(), + } + } + NistP384::OID => { + let sk = p384::SecretKey::try_from(private_key)?; + EcdsaKeypair::NistP384 { + public: sk.public_key().into(), + private: sk.into(), + } + } + NistP521::OID => { + let sk = p521::SecretKey::try_from(private_key)?; + EcdsaKeypair::NistP521 { + public: sk.public_key().into(), + private: sk.into(), + } + } + oid => return Err(Error::UnknownAlgorithm(oid)), + }) +} + +/// Encode into a password-protected PKCS#8-encoded private key. +pub fn encode_pkcs8_encrypted( + pass: &[u8], + rounds: u32, + key: &PrivateKey, +) -> Result, Error> { + let pvi_bytes = encode_pkcs8(key)?; + let pvi = PrivateKeyInfo::try_from(pvi_bytes.as_slice())?; + + use rand::RngCore; + let mut rng = rand::thread_rng(); + let mut salt = [0; 64]; + rng.fill_bytes(&mut salt); + let mut iv = [0; 16]; + rng.fill_bytes(&mut iv); + + let doc = pvi.encrypt_with_params( + pkcs5::pbes2::Parameters::pbkdf2_sha256_aes256cbc(rounds, &salt, &iv) + .map_err(|_| Error::InvalidParameters)?, + pass, + )?; + Ok(doc.as_bytes().to_vec()) +} + +/// Encode into a PKCS#8-encoded private key. +pub fn encode_pkcs8(key: &ssh_key::PrivateKey) -> Result, Error> { + let v = match key.key_data() { + ssh_key::private::KeypairData::Ed25519(ref pair) => { + let sk: ed25519_dalek::SigningKey = pair.try_into()?; + sk.to_pkcs8_der()? + } + ssh_key::private::KeypairData::Rsa(ref pair) => { + let sk: rsa::RsaPrivateKey = pair.try_into()?; + sk.to_pkcs8_der()? + } + ssh_key::private::KeypairData::Ecdsa(ref pair) => match pair { + EcdsaKeypair::NistP256 { private, .. } => { + let sk = p256::SecretKey::from_bytes(private.as_slice().into())?; + sk.to_pkcs8_der()? + } + EcdsaKeypair::NistP384 { private, .. } => { + let sk = p384::SecretKey::from_bytes(private.as_slice().into())?; + sk.to_pkcs8_der()? + } + EcdsaKeypair::NistP521 { private, .. } => { + let sk = p521::SecretKey::from_bytes(private.as_slice().into())?; + sk.to_pkcs8_der()? + } + }, + _ => { + let algo = key.algorithm(); + let kt = algo.as_str(); + return Err(Error::UnsupportedKeyType { + key_type_string: kt.into(), + key_type_raw: kt.as_bytes().into(), + }); + } + } + .as_bytes() + .to_vec(); + Ok(v) +} diff --git a/russh/src/keys/format/pkcs8_legacy.rs b/russh/src/keys/format/pkcs8_legacy.rs new file mode 100644 index 00000000..3c8e40b2 --- /dev/null +++ b/russh/src/keys/format/pkcs8_legacy.rs @@ -0,0 +1,222 @@ +use std::borrow::Cow; +use std::convert::TryFrom; + +use aes::cipher::{BlockDecryptMut, KeyIvInit}; +use aes::*; +use block_padding::Pkcs7; +use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData}; +use ssh_key::PrivateKey; +use yasna::BERReaderSeq; + +use super::Encryption; +use crate::keys::Error; + +const PBES2: &[u64] = &[1, 2, 840, 113549, 1, 5, 13]; +const ED25519: &[u64] = &[1, 3, 101, 112]; +const PBKDF2: &[u64] = &[1, 2, 840, 113549, 1, 5, 12]; +const AES256CBC: &[u64] = &[2, 16, 840, 1, 101, 3, 4, 1, 42]; +const HMAC_SHA256: &[u64] = &[1, 2, 840, 113549, 2, 9]; + +pub fn decode_pkcs8(ciphertext: &[u8], password: Option<&[u8]>) -> Result { + let secret = if let Some(pass) = password { + Cow::Owned(yasna::parse_der(ciphertext, |reader| { + reader.read_sequence(|reader| { + // Encryption parameters + let parameters = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == PBES2 { + asn1_read_pbes2(reader) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + // Ciphertext + let ciphertext = reader.next().read_bytes()?; + Ok(parameters.map(|p| p.decrypt(pass, &ciphertext))) + }) + })???) + } else { + Cow::Borrowed(ciphertext) + }; + yasna::parse_der(&secret, |reader| { + reader.read_sequence(|reader| { + let version = reader.next().read_u64()?; + if version == 0 { + Ok(Err(Error::CouldNotReadKey)) + } else if version == 1 { + Ok(read_key_v1(reader)) + } else { + Ok(Err(Error::CouldNotReadKey)) + } + }) + })? +} + +fn read_key_v1(reader: &mut BERReaderSeq) -> Result { + let oid = reader + .next() + .read_sequence(|reader| reader.next().read_oid())?; + if oid.components().as_slice() == ED25519 { + use ed25519_dalek::SigningKey; + let secret = { + let s = yasna::parse_der(&reader.next().read_bytes()?, |reader| reader.read_bytes())?; + + s.get(..ed25519_dalek::SECRET_KEY_LENGTH) + .ok_or(Error::KeyIsCorrupt) + .and_then(|s| SigningKey::try_from(s).map_err(|_| Error::CouldNotReadKey))? + }; + // Consume the public key + reader + .next() + .read_tagged(yasna::Tag::context(1), |reader| reader.read_bitvec())?; + + let pk = Ed25519PrivateKey::from(&secret); + Ok(PrivateKey::new( + KeypairData::Ed25519(Ed25519Keypair { + public: pk.clone().into(), + private: pk, + }), + "", + )?) + } else { + Err(Error::CouldNotReadKey) + } +} + +#[derive(Debug)] +enum Key { + K128([u8; 16]), + K256([u8; 32]), +} + +impl std::ops::Deref for Key { + type Target = [u8]; + fn deref(&self) -> &[u8] { + match *self { + Key::K128(ref k) => k, + Key::K256(ref k) => k, + } + } +} + +impl std::ops::DerefMut for Key { + fn deref_mut(&mut self) -> &mut [u8] { + match *self { + Key::K128(ref mut k) => k, + Key::K256(ref mut k) => k, + } + } +} + +enum Algorithms { + Pbes2(KeyDerivation, Encryption), +} + +impl Algorithms { + fn decrypt(&self, password: &[u8], cipher: &[u8]) -> Result, Error> { + match *self { + Algorithms::Pbes2(ref der, ref enc) => { + let mut key = enc.key(); + der.derive(password, &mut key)?; + let out = enc.decrypt(&key, cipher)?; + Ok(out) + } + } + } +} + +impl Encryption { + fn key(&self) -> Key { + match *self { + Encryption::Aes128Cbc(_) => Key::K128([0; 16]), + Encryption::Aes256Cbc(_) => Key::K256([0; 32]), + } + } + + fn decrypt(&self, key: &[u8], ciphertext: &[u8]) -> Result, Error> { + match *self { + Encryption::Aes128Cbc(ref iv) => { + #[allow(clippy::unwrap_used)] // parameters are static + let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); + let mut dec = ciphertext.to_vec(); + Ok(c.decrypt_padded_mut::(&mut dec)?.into()) + } + Encryption::Aes256Cbc(ref iv) => { + #[allow(clippy::unwrap_used)] // parameters are static + let c = cbc::Decryptor::::new_from_slices(key, iv).unwrap(); + let mut dec = ciphertext.to_vec(); + Ok(c.decrypt_padded_mut::(&mut dec)?.into()) + } + } + } +} + +enum KeyDerivation { + Pbkdf2 { salt: Vec, rounds: u64 }, +} + +impl KeyDerivation { + fn derive(&self, password: &[u8], key: &mut [u8]) -> Result<(), Error> { + match *self { + KeyDerivation::Pbkdf2 { ref salt, rounds } => { + pbkdf2::pbkdf2::>(password, salt, rounds as u32, key) + .map_err(|_| Error::InvalidParameters) + // pbkdf2_hmac(password, salt, rounds as usize, digest, key)? + } + } + } +} +fn asn1_read_pbes2( + reader: &mut yasna::BERReaderSeq, +) -> Result, yasna::ASN1Error> { + reader.next().read_sequence(|reader| { + // PBES2 has two components. + // 1. Key generation algorithm + let keygen = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == PBKDF2 { + asn1_read_pbkdf2(reader) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + // 2. Encryption algorithm. + let algorithm = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == AES256CBC { + asn1_read_aes256cbc(reader) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + Ok(keygen.and_then(|keygen| algorithm.map(|algo| Algorithms::Pbes2(keygen, algo)))) + }) +} + +fn asn1_read_pbkdf2( + reader: &mut yasna::BERReaderSeq, +) -> Result, yasna::ASN1Error> { + reader.next().read_sequence(|reader| { + let salt = reader.next().read_bytes()?; + let rounds = reader.next().read_u64()?; + let digest = reader.next().read_sequence(|reader| { + let oid = reader.next().read_oid()?; + if oid.components().as_slice() == HMAC_SHA256 { + reader.next().read_null()?; + Ok(Ok(())) + } else { + Ok(Err(Error::InvalidParameters)) + } + })?; + Ok(digest.map(|()| KeyDerivation::Pbkdf2 { salt, rounds })) + }) +} + +fn asn1_read_aes256cbc( + reader: &mut yasna::BERReaderSeq, +) -> Result, yasna::ASN1Error> { + let iv = reader.next().read_bytes()?; + let mut i = [0; 16]; + i.clone_from_slice(&iv); + Ok(Ok(Encryption::Aes256Cbc(i))) +} diff --git a/russh/src/keys/format/tests.rs b/russh/src/keys/format/tests.rs new file mode 100644 index 00000000..245c1ffb --- /dev/null +++ b/russh/src/keys/format/tests.rs @@ -0,0 +1,12 @@ +use super::decode_secret_key; + +#[test] +fn test_ec_private_key() { + let key = r#"-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDBNK0jwKqqf8zkM+Z2l++9r8bzdTS/XCoB4N1J07dPxpByyJyGbhvIy +1kLvY2gIvlmgBwYFK4EEACKhZANiAAQvPxAK2RhvH/k5inDa9oMxUZPvvb9fq8G3 +9dKW1tS+ywhejnKeu/48HXAXgx2g6qMJjEPpcTy/DaYm12r3GTaRzOBQmxSItStk +lpQg5vf23Fc9fFrQ9AnQKrb1dgTkoxQ= +-----END EC PRIVATE KEY-----"#; + decode_secret_key(&key, None).unwrap(); +} diff --git a/russh/src/keys/key.rs b/russh/src/keys/key.rs new file mode 100644 index 00000000..81f4ccc9 --- /dev/null +++ b/russh/src/keys/key.rs @@ -0,0 +1,121 @@ +// Copyright 2016 Pierre-Étienne Meunier +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +use ssh_encoding::Decode; +use ssh_key::public::KeyData; +use ssh_key::{Algorithm, EcdsaCurve, HashAlg, PublicKey}; + +use crate::keys::Error; + +pub trait PublicKeyExt { + fn decode(bytes: &[u8]) -> Result; +} + +impl PublicKeyExt for PublicKey { + fn decode(mut bytes: &[u8]) -> Result { + let key = KeyData::decode(&mut bytes)?; + Ok(PublicKey::new(key, "")) + } +} + +#[doc(hidden)] +pub trait Verify { + fn verify_client_auth(&self, buffer: &[u8], sig: &[u8]) -> bool; + fn verify_server_auth(&self, buffer: &[u8], sig: &[u8]) -> bool; +} + +/// Parse a public key from a byte slice. +pub fn parse_public_key(mut p: &[u8]) -> Result { + Ok(ssh_key::public::KeyData::decode(&mut p)?.into()) +} + +/// Obtain a cryptographic-safe random number generator. +pub fn safe_rng() -> impl rand::CryptoRng + rand::RngCore { + rand::thread_rng() +} + +mod private_key_with_hash_alg { + use std::ops::Deref; + use std::sync::Arc; + + use ssh_key::Algorithm; + + use crate::helpers::AlgorithmExt; + + /// Helper structure to correlate a key and (in case of RSA) a hash algorithm. + /// Only used for authentication, not key storage as RSA keys do not inherently + /// have a hash algorithm associated with them. + #[derive(Clone, Debug)] + pub struct PrivateKeyWithHashAlg { + key: Arc, + hash_alg: Option, + } + + impl PrivateKeyWithHashAlg { + /// Direct constructor. + /// + /// For RSA, passing `None` is mapped to the legacy `sha-rsa` (SHA-1). + /// For other keys, `hash_alg` is ignored. + pub fn new( + key: Arc, + mut hash_alg: Option, + ) -> Self { + if !key.algorithm().is_rsa() { + hash_alg = None; + } + Self { key, hash_alg } + } + + pub fn algorithm(&self) -> Algorithm { + self.key.algorithm().with_hash_alg(self.hash_alg) + } + + pub fn hash_alg(&self) -> Option { + self.hash_alg + } + } + + impl Deref for PrivateKeyWithHashAlg { + type Target = crate::keys::PrivateKey; + + fn deref(&self) -> &Self::Target { + &self.key + } + } +} + +pub use private_key_with_hash_alg::PrivateKeyWithHashAlg; + +pub const ALL_KEY_TYPES: &[Algorithm] = &[ + Algorithm::Dsa, + Algorithm::Ecdsa { + curve: EcdsaCurve::NistP256, + }, + Algorithm::Ecdsa { + curve: EcdsaCurve::NistP384, + }, + Algorithm::Ecdsa { + curve: EcdsaCurve::NistP521, + }, + Algorithm::Ed25519, + Algorithm::Rsa { hash: None }, + Algorithm::Rsa { + hash: Some(HashAlg::Sha256), + }, + Algorithm::Rsa { + hash: Some(HashAlg::Sha512), + }, + Algorithm::SkEcdsaSha2NistP256, + Algorithm::SkEd25519, +]; diff --git a/russh/src/keys/known_hosts.rs b/russh/src/keys/known_hosts.rs new file mode 100644 index 00000000..879cbfce --- /dev/null +++ b/russh/src/keys/known_hosts.rs @@ -0,0 +1,243 @@ +use std::borrow::Cow; +use std::fs::{File, OpenOptions}; +use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; + +use data_encoding::BASE64_MIME; +use hmac::{Hmac, Mac}; +use log::debug; +use sha1::Sha1; + +use crate::keys::Error; + +/// Check whether the host is known, from its standard location. +pub fn check_known_hosts( + host: &str, + port: u16, + pubkey: &ssh_key::PublicKey, +) -> Result { + check_known_hosts_path(host, port, pubkey, known_hosts_path()?) +} + +/// Check that a server key matches the one recorded in file `path`. +pub fn check_known_hosts_path>( + host: &str, + port: u16, + pubkey: &ssh_key::PublicKey, + path: P, +) -> Result { + let check = known_host_keys_path(host, port, path)? + .into_iter() + .map(|(line, recorded)| { + match ( + pubkey.algorithm() == recorded.algorithm(), + *pubkey == recorded, + ) { + (true, true) => Ok(true), + (true, false) => Err(Error::KeyChanged { line }), + _ => Ok(false), + } + }) + // If any Err was returned, we stop here + .collect::, Error>>()? + .into_iter() + // Now we check the results for a match + .any(|x| x); + + Ok(check) +} + +#[cfg(target_os = "windows")] +fn known_hosts_path() -> Result { + if let Some(home_dir) = home::home_dir() { + Ok(home_dir.join("ssh").join("known_hosts")) + } else { + Err(Error::NoHomeDir) + } +} + +#[cfg(not(target_os = "windows"))] +fn known_hosts_path() -> Result { + if let Some(home_dir) = home::home_dir() { + Ok(home_dir.join(".ssh").join("known_hosts")) + } else { + Err(Error::NoHomeDir) + } +} + +/// Get the server key that matches the one recorded in the user's known_hosts file. +pub fn known_host_keys(host: &str, port: u16) -> Result, Error> { + known_host_keys_path(host, port, known_hosts_path()?) +} + +/// Get the server key that matches the one recorded in `path`. +pub fn known_host_keys_path>( + host: &str, + port: u16, + path: P, +) -> Result, Error> { + use crate::keys::parse_public_key_base64; + + let mut f = if let Ok(f) = File::open(path) { + BufReader::new(f) + } else { + return Ok(vec![]); + }; + let mut buffer = String::new(); + + let host_port = if port == 22 { + Cow::Borrowed(host) + } else { + Cow::Owned(format!("[{}]:{}", host, port)) + }; + debug!("host_port = {:?}", host_port); + let mut line = 1; + let mut matches = vec![]; + while f.read_line(&mut buffer)? > 0 { + { + if buffer.as_bytes().first() == Some(&b'#') { + buffer.clear(); + continue; + } + debug!("line = {:?}", buffer); + let mut s = buffer.split(' '); + let hosts = s.next(); + let _ = s.next(); + let key = s.next(); + if let (Some(h), Some(k)) = (hosts, key) { + debug!("{:?} {:?}", h, k); + if match_hostname(&host_port, h) { + matches.push((line, parse_public_key_base64(k)?)); + } + } + } + buffer.clear(); + line += 1; + } + Ok(matches) +} + +fn match_hostname(host: &str, pattern: &str) -> bool { + for entry in pattern.split(',') { + if entry.starts_with("|1|") { + let mut parts = entry.split('|').skip(2); + let Some(Ok(salt)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { + continue; + }; + let Some(Ok(hash)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { + continue; + }; + if let Ok(hmac) = Hmac::::new_from_slice(&salt) { + if hmac.chain_update(host).verify_slice(&hash).is_ok() { + return true; + } + } + } else if host == entry { + return true; + } + } + false +} + +/// Record a host's public key into the user's known_hosts file. +pub fn learn_known_hosts(host: &str, port: u16, pubkey: &ssh_key::PublicKey) -> Result<(), Error> { + learn_known_hosts_path(host, port, pubkey, known_hosts_path()?) +} + +/// Record a host's public key into a nonstandard location. +pub fn learn_known_hosts_path>( + host: &str, + port: u16, + pubkey: &ssh_key::PublicKey, + path: P, +) -> Result<(), Error> { + if let Some(parent) = path.as_ref().parent() { + std::fs::create_dir_all(parent)? + } + let mut file = OpenOptions::new() + .read(true) + .append(true) + .create(true) + .open(path)?; + + // Test whether the known_hosts file ends with a \n + let mut buf = [0; 1]; + let mut ends_in_newline = false; + if file.seek(SeekFrom::End(-1)).is_ok() { + file.read_exact(&mut buf)?; + ends_in_newline = buf[0] == b'\n'; + } + + // Write the key. + file.seek(SeekFrom::End(0))?; + let mut file = std::io::BufWriter::new(file); + if !ends_in_newline { + file.write_all(b"\n")?; + } + if port != 22 { + write!(file, "[{}]:{} ", host, port)? + } else { + write!(file, "{} ", host)? + } + file.write_all(pubkey.to_openssh()?.as_bytes())?; + file.write_all(b"\n")?; + Ok(()) +} + +#[cfg(test)] +mod test { + use std::fs::File; + + use super::*; + use crate::keys::parse_public_key_base64; + + #[test] + fn test_check_known_hosts() { + env_logger::try_init().unwrap_or(()); + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("known_hosts"); + { + let mut f = File::create(&path).unwrap(); + f.write_all(b"[localhost]:13265 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ\n").unwrap(); + f.write_all(b"#pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"|1|O33ESRMWPVkMYIwJ1Uw+n877jTo=|nuuC5vEqXlEZ/8BXQR7m619W6Ak= ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF\n").unwrap(); + } + + // Valid key, non-standard port. + let host = "localhost"; + let port = 13265; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + + // Valid key, hashed. + let host = "example.com"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + + // Valid key, several hosts, port 22 + let host = "pijul.org"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + + // Now with the key in a comment above, check that it's not recognized + let host = "pijul.org"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).is_err()); + } +} diff --git a/russh-keys/src/lib.rs b/russh/src/keys/mod.rs similarity index 58% rename from russh-keys/src/lib.rs rename to russh/src/keys/mod.rs index 09287aa0..f5889e61 100644 --- a/russh-keys/src/lib.rs +++ b/russh/src/keys/mod.rs @@ -1,39 +1,32 @@ -#![deny(trivial_casts, unstable_features, unused_import_braces)] -#![deny( - clippy::unwrap_used, - clippy::expect_used, - clippy::indexing_slicing, - clippy::panic -)] //! This crate contains methods to deal with SSH keys, as defined in //! crate Russh. This includes in particular various functions for //! opening key files, deciphering encrypted keys, and dealing with //! agents. //! -//! The following example (which uses the `openssl` feature) shows how -//! to do all these in a single example: start and SSH agent server, -//! connect to it with a client, decipher an encrypted private key -//! (the password is `b"blabla"`), send it to the agent, and ask the -//! agent to sign a piece of data (`b"Please sign this", below). +//! The following example shows how to do all these in a single example: +//! start and SSH agent server, connect to it with a client, decipher +//! an encrypted private key (the password is `b"blabla"`), send it to +//! the agent, and ask the agent to sign a piece of data +//! (`b"Please sign this"`, below). //! //!``` -//! use russh_keys::*; +//! use russh::keys::*; //! use futures::Future; //! //! #[derive(Clone)] //! struct X{} //! impl agent::server::Agent for X { -//! fn confirm(self, _: std::sync::Arc) -> Box + Send + Unpin> { +//! fn confirm(self, _: std::sync::Arc) -> Box + Send + Unpin> { //! Box::new(futures::future::ready((self, true))) //! } //! } //! //! const PKCS8_ENCRYPTED: &'static str = "-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIIFLTBXBgkqhkiG9w0BBQ0wSjApBgkqhkiG9w0BBQwwHAQITo1O0b8YrS0CAggA\nMAwGCCqGSIb3DQIJBQAwHQYJYIZIAWUDBAEqBBBtLH4T1KOfo1GGr7salhR8BIIE\n0KN9ednYwcTGSX3hg7fROhTw7JAJ1D4IdT1fsoGeNu2BFuIgF3cthGHe6S5zceI2\nMpkfwvHbsOlDFWMUIAb/VY8/iYxhNmd5J6NStMYRC9NC0fVzOmrJqE1wITqxtORx\nIkzqkgFUbaaiFFQPepsh5CvQfAgGEWV329SsTOKIgyTj97RxfZIKA+TR5J5g2dJY\nj346SvHhSxJ4Jc0asccgMb0HGh9UUDzDSql0OIdbnZW5KzYJPOx+aDqnpbz7UzY/\nP8N0w/pEiGmkdkNyvGsdttcjFpOWlLnLDhtLx8dDwi/sbEYHtpMzsYC9jPn3hnds\nTcotqjoSZ31O6rJD4z18FOQb4iZs3MohwEdDd9XKblTfYKM62aQJWH6cVQcg+1C7\njX9l2wmyK26Tkkl5Qg/qSfzrCveke5muZgZkFwL0GCcgPJ8RixSB4GOdSMa/hAMU\nkvFAtoV2GluIgmSe1pG5cNMhurxM1dPPf4WnD+9hkFFSsMkTAuxDZIdDk3FA8zof\nYhv0ZTfvT6V+vgH3Hv7Tqcxomy5Qr3tj5vvAqqDU6k7fC4FvkxDh2mG5ovWvc4Nb\nXv8sed0LGpYitIOMldu6650LoZAqJVv5N4cAA2Edqldf7S2Iz1QnA/usXkQd4tLa\nZ80+sDNv9eCVkfaJ6kOVLk/ghLdXWJYRLenfQZtVUXrPkaPpNXgD0dlaTN8KuvML\nUw/UGa+4ybnPsdVflI0YkJKbxouhp4iB4S5ACAwqHVmsH5GRnujf10qLoS7RjDAl\no/wSHxdT9BECp7TT8ID65u2mlJvH13iJbktPczGXt07nBiBse6OxsClfBtHkRLzE\nQF6UMEXsJnIIMRfrZQnduC8FUOkfPOSXc8r9SeZ3GhfbV/DmWZvFPCpjzKYPsM5+\nN8Bw/iZ7NIH4xzNOgwdp5BzjH9hRtCt4sUKVVlWfEDtTnkHNOusQGKu7HkBF87YZ\nRN/Nd3gvHob668JOcGchcOzcsqsgzhGMD8+G9T9oZkFCYtwUXQU2XjMN0R4VtQgZ\nrAxWyQau9xXMGyDC67gQ5xSn+oqMK0HmoW8jh2LG/cUowHFAkUxdzGadnjGhMOI2\nzwNJPIjF93eDF/+zW5E1l0iGdiYyHkJbWSvcCuvTwma9FIDB45vOh5mSR+YjjSM5\nnq3THSWNi7Cxqz12Q1+i9pz92T2myYKBBtu1WDh+2KOn5DUkfEadY5SsIu/Rb7ub\n5FBihk2RN3y/iZk+36I69HgGg1OElYjps3D+A9AjVby10zxxLAz8U28YqJZm4wA/\nT0HLxBiVw+rsHmLP79KvsT2+b4Diqih+VTXouPWC/W+lELYKSlqnJCat77IxgM9e\nYIhzD47OgWl33GJ/R10+RDoDvY4koYE+V5NLglEhbwjloo9Ryv5ywBJNS7mfXMsK\n/uf+l2AscZTZ1mhtL38efTQCIRjyFHc3V31DI0UdETADi+/Omz+bXu0D5VvX+7c6\nb1iVZKpJw8KUjzeUV8yOZhvGu3LrQbhkTPVYL555iP1KN0Eya88ra+FUKMwLgjYr\nJkUx4iad4dTsGPodwEP/Y9oX/Qk3ZQr+REZ8lg6IBoKKqqrQeBJ9gkm1jfKE6Xkc\nCog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux\n-----END ENCRYPTED PRIVATE KEY-----\n"; //! -//! #[cfg(all(unix, feature = "openssl"))] +//! #[cfg(unix)] //! fn main() { //! env_logger::try_init().unwrap_or(()); -//! let dir = tempdir::TempDir::new("russh").unwrap(); +//! let dir = tempfile::tempdir().unwrap(); //! let agent_path = dir.path().join("agent"); //! //! let mut core = tokio::runtime::Runtime::new().unwrap(); @@ -42,49 +35,58 @@ //! core.spawn(async move { //! let mut listener = tokio::net::UnixListener::bind(&agent_path_) //! .unwrap(); -//! russh_keys::agent::server::serve(tokio_stream::wrappers::UnixListenerStream::new(listener), X {}).await +//! russh::keys::agent::server::serve(tokio_stream::wrappers::UnixListenerStream::new(listener), X {}).await //! }); //! let key = decode_secret_key(PKCS8_ENCRYPTED, Some("blabla")).unwrap(); -//! let public = key.clone_public_key().unwrap(); +//! let public = key.public_key().clone(); //! core.block_on(async move { //! let stream = tokio::net::UnixStream::connect(&agent_path).await?; //! let mut client = agent::client::AgentClient::connect(stream); //! client.add_identity(&key, &[agent::Constraint::KeyLifetime { seconds: 60 }]).await?; //! client.request_identities().await?; //! let buf = b"signed message"; -//! let sig = client.sign_request(&public, russh_cryptovec::CryptoVec::from_slice(&buf[..])).await.1.unwrap(); +//! let sig = client.sign_request(&public, None, russh_cryptovec::CryptoVec::from_slice(&buf[..])).await.unwrap(); //! // Here, `sig` is encoded in a format usable internally by the SSH protocol. //! Ok::<(), Error>(()) //! }).unwrap() //! } //! -//! #[cfg(any(not(unix), not(feature = "openssl")))] +//! #[cfg(not(unix))] //! fn main() {} //! //! ``` -use std::borrow::Cow; -use std::fs::{File, OpenOptions}; -use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; +use std::fs::File; +use std::io::Read; use std::path::Path; +use std::string::FromUtf8Error; use aes::cipher::block_padding::UnpadError; use aes::cipher::inout::PadError; -use byteorder::{BigEndian, WriteBytesExt}; use data_encoding::BASE64_MIME; -use log::debug; use thiserror::Error; -pub mod encoding; +use crate::helpers::EncodedExt; + pub mod key; -pub mod signature; +pub use key::PrivateKeyWithHashAlg; mod format; pub use format::*; +// Reexports +pub use signature; +pub use ssh_encoding; +pub use ssh_key::{self, Algorithm, Certificate, EcdsaCurve, HashAlg, PrivateKey, PublicKey}; -/// A module to write SSH agent. +/// OpenSSH agent protocol implementation pub mod agent; +#[cfg(not(target_arch = "wasm32"))] +pub mod known_hosts; + +#[cfg(not(target_arch = "wasm32"))] +pub use known_hosts::{check_known_hosts, check_known_hosts_path}; + #[derive(Debug, Error)] pub enum Error { /// The key could not be read, for an unknown reason @@ -99,6 +101,9 @@ pub enum Error { /// The type of the key is unsupported #[error("Invalid Ed25519 key data")] Ed25519KeyError(#[from] ed25519_dalek::SignatureError), + /// The type of the key is unsupported + #[error("Invalid ECDSA key data")] + EcdsaKeyError(#[from] p256::elliptic_curve::Error), /// The key is encrypted (should supply a password?) #[error("The key is encrypted")] KeyIsEncrypted, @@ -112,14 +117,18 @@ pub enum Error { #[error("The server key changed at line {}", line)] KeyChanged { line: usize }, /// The key uses an unsupported algorithm - #[error("Unknown key algorithm")] - UnknownAlgorithm(yasna::models::ObjectIdentifier), + #[error("Unknown key algorithm: {0}")] + UnknownAlgorithm(::pkcs8::ObjectIdentifier), /// Index out of bounds #[error("Index out of bounds")] IndexOutOfBounds, /// Unknown signature type #[error("Unknown signature type: {}", sig_type)] UnknownSignatureType { sig_type: String }, + #[error("Invalid signature")] + InvalidSignature, + #[error("Invalid parameters")] + InvalidParameters, /// Agent protocol error #[error("Agent protocol error")] AgentProtocolError, @@ -128,9 +137,8 @@ pub enum Error { #[error(transparent)] IO(#[from] std::io::Error), - #[cfg(feature = "openssl")] - #[error(transparent)] - Openssl(#[from] openssl::error::ErrorStack), + #[error("Rsa: {0}")] + Rsa(#[from] rsa::Error), #[error(transparent)] Pad(#[from] PadError), @@ -140,8 +148,22 @@ pub enum Error { #[error("Base64 decoding error: {0}")] Decode(#[from] data_encoding::DecodeError), - #[error("ASN1 decoding error: {0}")] - ASN1(yasna::ASN1Error), + #[error("Der: {0}")] + Der(#[from] der::Error), + #[error("Spki: {0}")] + Spki(#[from] spki::Error), + #[error("Pkcs1: {0}")] + Pkcs1(#[from] pkcs1::Error), + #[error("Pkcs8: {0}")] + Pkcs8(#[from] ::pkcs8::Error), + #[error("Sec1: {0}")] + Sec1(#[from] sec1::Error), + + #[error("SshKey: {0}")] + SshKey(#[from] ssh_key::Error), + #[error("SshEncoding: {0}")] + SshEncoding(#[from] ssh_encoding::Error), + #[error("Environment variable `{0}` not found")] EnvVar(&'static str), #[error( @@ -149,23 +171,32 @@ pub enum Error { points to a nonexistent file or directory." )] BadAuthSock, + + #[error(transparent)] + Utf8(#[from] FromUtf8Error), + + #[error("ASN1 decoding error: {0}")] + #[cfg(feature = "legacy-ed25519-pkcs8-parser")] + LegacyASN1(::yasna::ASN1Error), + + #[cfg(windows)] + #[error("Pageant: {0}")] + Pageant(#[from] pageant::Error), } +#[cfg(feature = "legacy-ed25519-pkcs8-parser")] impl From for Error { fn from(e: yasna::ASN1Error) -> Error { - Error::ASN1(e) + Error::LegacyASN1(e) } } -const KEYTYPE_ED25519: &[u8] = b"ssh-ed25519"; -const KEYTYPE_RSA: &[u8] = b"ssh-rsa"; - -/// Load a public key from a file. Ed25519 and RSA keys are supported. +/// Load a public key from a file. Ed25519, EC-DSA and RSA keys are supported. /// /// ``` -/// russh_keys::load_public_key("../files/id_ed25519.pub").unwrap(); +/// russh::keys::load_public_key("../files/id_ed25519.pub").unwrap(); /// ``` -pub fn load_public_key>(path: P) -> Result { +pub fn load_public_key>(path: P) -> Result { let mut pubkey = String::new(); let mut file = File::open(path.as_ref())?; file.read_to_string(&mut pubkey)?; @@ -183,15 +214,11 @@ pub fn load_public_key>(path: P) -> Result /// as `ssh-ed25519 AAAAC3N...`). /// /// ``` -/// russh_keys::parse_public_key_base64("AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ").is_ok(); +/// russh::keys::parse_public_key_base64("AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ").is_ok(); /// ``` -pub fn parse_public_key_base64(key: &str) -> Result { +pub fn parse_public_key_base64(key: &str) -> Result { let base = BASE64_MIME.decode(key.as_bytes())?; - key::parse_public_key( - &base, - #[cfg(feature = "openssl")] - None, - ) + key::parse_public_key(&base) } pub trait PublicKeyBase64 { @@ -205,239 +232,55 @@ pub trait PublicKeyBase64 { } } -impl PublicKeyBase64 for key::PublicKey { +impl PublicKeyBase64 for ssh_key::PublicKey { fn public_key_bytes(&self) -> Vec { - let mut s = Vec::new(); - match *self { - key::PublicKey::Ed25519(ref publickey) => { - let name = b"ssh-ed25519"; - #[allow(clippy::unwrap_used)] // Vec<>.write can't fail - s.write_u32::(name.len() as u32).unwrap(); - s.extend_from_slice(name); - #[allow(clippy::unwrap_used)] // Vec<>.write can't fail - s.write_u32::(publickey.as_bytes().len() as u32) - .unwrap(); - s.extend_from_slice(publickey.as_bytes()); - } - #[cfg(feature = "openssl")] - key::PublicKey::RSA { ref key, .. } => { - use encoding::Encoding; - let name = b"ssh-rsa"; - #[allow(clippy::unwrap_used)] // Vec<>.write_all can't fail - s.write_u32::(name.len() as u32).unwrap(); - s.extend_from_slice(name); - #[allow(clippy::unwrap_used)] // TODO check - s.extend_ssh_mpint(&key.0.rsa().unwrap().e().to_vec()); - #[allow(clippy::unwrap_used)] // TODO check - s.extend_ssh_mpint(&key.0.rsa().unwrap().n().to_vec()); - } - } - s + self.key_data().encoded().unwrap_or_default() } } -impl PublicKeyBase64 for key::KeyPair { +impl PublicKeyBase64 for PrivateKey { fn public_key_bytes(&self) -> Vec { - let name = self.name().as_bytes(); - let mut s = Vec::new(); - #[allow(clippy::unwrap_used)] // Vec<>.write_all can't fail - s.write_u32::(name.len() as u32).unwrap(); - s.extend_from_slice(name); - match *self { - key::KeyPair::Ed25519(ref key) => { - let public = key.verifying_key().to_bytes(); - #[allow(clippy::unwrap_used)] // Vec<>.write can't fail - s.write_u32::(public.len() as u32).unwrap(); - s.extend_from_slice(public.as_slice()); - } - #[cfg(feature = "openssl")] - key::KeyPair::RSA { ref key, .. } => { - use encoding::Encoding; - s.extend_ssh_mpint(&key.e().to_vec()); - s.extend_ssh_mpint(&key.n().to_vec()); - } - } - s + self.public_key().public_key_bytes() } } -/// Write a public key onto the provided `Write`, encoded in base-64. -pub fn write_public_key_base64( - mut w: W, - publickey: &key::PublicKey, -) -> Result<(), Error> { - let pk = publickey.public_key_base64(); - writeln!(w, "{} {}", publickey.name(), pk)?; - Ok(()) -} - /// Load a secret key, deciphering it with the supplied password if necessary. pub fn load_secret_key>( secret_: P, password: Option<&str>, -) -> Result { +) -> Result { let mut secret_file = std::fs::File::open(secret_)?; let mut secret = String::new(); secret_file.read_to_string(&mut secret)?; decode_secret_key(&secret, password) } +/// Load a openssh certificate +pub fn load_openssh_certificate>(cert_: P) -> Result { + let mut cert_file = std::fs::File::open(cert_)?; + let mut cert = String::new(); + cert_file.read_to_string(&mut cert)?; + + Certificate::from_openssh(&cert) +} + fn is_base64_char(c: char) -> bool { - ('a'..='z').contains(&c) - || ('A'..='Z').contains(&c) - || ('0'..='9').contains(&c) + c.is_ascii_lowercase() + || c.is_ascii_uppercase() + || c.is_ascii_digit() || c == '/' || c == '+' || c == '=' } -/// Record a host's public key into a nonstandard location. -pub fn learn_known_hosts_path>( - host: &str, - port: u16, - pubkey: &key::PublicKey, - path: P, -) -> Result<(), Error> { - if let Some(parent) = path.as_ref().parent() { - std::fs::create_dir_all(parent)? - } - let mut file = OpenOptions::new() - .read(true) - .append(true) - .create(true) - .open(path)?; - - // Test whether the known_hosts file ends with a \n - let mut buf = [0; 1]; - let mut ends_in_newline = false; - if file.seek(SeekFrom::End(-1)).is_ok() { - file.read_exact(&mut buf)?; - ends_in_newline = buf[0] == b'\n'; - } - - // Write the key. - file.seek(SeekFrom::End(0))?; - let mut file = std::io::BufWriter::new(file); - if !ends_in_newline { - file.write_all(b"\n")?; - } - if port != 22 { - write!(file, "[{}]:{} ", host, port)? - } else { - write!(file, "{} ", host)? - } - write_public_key_base64(&mut file, pubkey)?; - file.write_all(b"\n")?; - Ok(()) -} - -/// Check that a server key matches the one recorded in file `path`. -pub fn check_known_hosts_path>( - host: &str, - port: u16, - pubkey: &key::PublicKey, - path: P, -) -> Result { - let mut f = if let Ok(f) = File::open(path) { - BufReader::new(f) - } else { - return Ok(false); - }; - let mut buffer = String::new(); - - let host_port = if port == 22 { - Cow::Borrowed(host) - } else { - Cow::Owned(format!("[{}]:{}", host, port)) - }; - debug!("host_port = {:?}", host_port); - let mut line = 1; - while f.read_line(&mut buffer)? > 0 { - { - if buffer.as_bytes().first() == Some(&b'#') { - buffer.clear(); - continue; - } - debug!("line = {:?}", buffer); - let mut s = buffer.split(' '); - let hosts = s.next(); - let _ = s.next(); - let key = s.next(); - if let (Some(h), Some(k)) = (hosts, key) { - debug!("{:?} {:?}", h, k); - let host_matches = h.split(',').any(|x| x == host_port); - if host_matches { - if &parse_public_key_base64(k)? == pubkey { - return Ok(true); - } else { - return Err(Error::KeyChanged { line }); - } - } - } - } - buffer.clear(); - line += 1; - } - Ok(false) -} - -/// Record a host's public key into the user's known_hosts file. -#[cfg(target_os = "windows")] -pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push("ssh"); - known_host_file.push("known_hosts"); - learn_known_hosts_path(host, port, pubkey, &known_host_file) - } else { - Err(Error::NoHomeDir) - } -} - -/// Record a host's public key into the user's known_hosts file. -#[cfg(not(target_os = "windows"))] -pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push(".ssh"); - known_host_file.push("known_hosts"); - learn_known_hosts_path(host, port, pubkey, &known_host_file) - } else { - Err(Error::NoHomeDir) - } -} - -/// Check whether the host is known, from its standard location. -#[cfg(target_os = "windows")] -pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push("ssh"); - known_host_file.push("known_hosts"); - check_known_hosts_path(host, port, pubkey, &known_host_file) - } else { - Err(Error::NoHomeDir.into()) - } -} - -/// Check whether the host is known, from its standard location. -#[cfg(not(target_os = "windows"))] -pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { - if let Some(mut known_host_file) = dirs::home_dir() { - known_host_file.push(".ssh"); - known_host_file.push("known_hosts"); - check_known_hosts_path(host, port, pubkey, &known_host_file) - } else { - Err(Error::NoHomeDir) - } -} - #[cfg(test)] mod test { - use std::fs::File; - use std::io::Write; - #[cfg(feature = "openssl")] + #[cfg(unix)] use futures::Future; use super::*; + use crate::keys::key::PublicKeyExt; const ED25519_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jYmMAAAAGYmNyeXB0AAAAGAAAABDLGyfA39 @@ -458,7 +301,6 @@ dP3jryYgvsCIBAA5jMWSjrmnOTXhidqcOy4xYCrAttzSnZ/cUadfBenL+DQq6neffw7j8r sJWR7W+cGvJ/vLsw== -----END OPENSSH PRIVATE KEY-----"; - #[cfg(feature = "openssl")] const RSA_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn NhAAAAAwEAAQAAAQEAuSvQ9m76zhRB4m0BUKPf17lwccj7KQ1Qtse63AOqP/VYItqEH8un @@ -499,66 +341,167 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== decode_secret_key(ED25519_AESCTR_KEY, Some("test")).unwrap(); } + // Key from RFC 8410 Section 10.3. This is a key using PrivateKeyInfo structure. + const RFC8410_ED25519_PRIVATE_ONLY_KEY: &str = "-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC +-----END PRIVATE KEY-----"; + + #[test] + fn test_decode_rfc8410_ed25519_private_only_key() { + env_logger::try_init().unwrap_or(()); + assert!( + decode_secret_key(RFC8410_ED25519_PRIVATE_ONLY_KEY, None) + .unwrap() + .algorithm() + == ssh_key::Algorithm::Ed25519, + ); + // We always encode public key, skip test_decode_encode_symmetry. + } + + // Key from RFC 8410 Section 10.3. This is a key using OneAsymmetricKey structure. + const RFC8410_ED25519_PRIVATE_PUBLIC_KEY: &str = "-----BEGIN PRIVATE KEY----- +MHICAQEwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC +oB8wHQYKKoZIhvcNAQkJFDEPDA1DdXJkbGUgQ2hhaXJzgSEAGb9ECWmEzf6FQbrB +Z9w7lshQhqowtrbLDFw4rXAxZuE= +-----END PRIVATE KEY-----"; + + #[test] + fn test_decode_rfc8410_ed25519_private_public_key() { + env_logger::try_init().unwrap_or(()); + assert!( + decode_secret_key(RFC8410_ED25519_PRIVATE_PUBLIC_KEY, None) + .unwrap() + .algorithm() + == ssh_key::Algorithm::Ed25519, + ); + // We can't encode attributes, skip test_decode_encode_symmetry. + } + #[test] - #[cfg(feature = "openssl")] fn test_decode_rsa_secret_key() { env_logger::try_init().unwrap_or(()); decode_secret_key(RSA_KEY, None).unwrap(); } #[test] - #[cfg(feature = "openssl")] + fn test_decode_openssh_p256_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 256 -m rfc4716 -f $file + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS +1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQQ/i+HCsmZZPy0JhtT64vW7EmeA1DeA +M/VnPq3vAhu+xooJ7IMMK3lUHlBDosyvA2enNbCWyvNQc25dVt4oh9RhAAAAqHG7WMFxu1 +jBAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBD+L4cKyZlk/LQmG +1Pri9bsSZ4DUN4Az9Wc+re8CG77GignsgwwreVQeUEOizK8DZ6c1sJbK81Bzbl1W3iiH1G +EAAAAgLAmXR6IlN0SdiD6o8qr+vUr0mXLbajs/m0UlegElOmoAAAANcm9iZXJ0QGJic2Rl +dgECAw== +-----END OPENSSH PRIVATE KEY----- +"; + assert!( + decode_secret_key(key, None).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP256 + }, + ); + } + + #[test] + fn test_decode_openssh_p384_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 384 -m rfc4716 -f $file + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS +1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQTkLnKPk/1NZD9mQ8XoebD7ASv9/svh +5jO75HF7RYAqKK3fl5wsHe4VTJAOT3qH841yTcK79l0dwhHhHeg60byL7F9xOEzr2kqGeY +Uwrl7fVaL7hfHzt6z+sG8smSQ3tF8AAADYHjjBch44wXIAAAATZWNkc2Etc2hhMi1uaXN0 +cDM4NAAAAAhuaXN0cDM4NAAAAGEE5C5yj5P9TWQ/ZkPF6Hmw+wEr/f7L4eYzu+Rxe0WAKi +it35ecLB3uFUyQDk96h/ONck3Cu/ZdHcIR4R3oOtG8i+xfcThM69pKhnmFMK5e31Wi+4Xx +87es/rBvLJkkN7RfAAAAMFzt6053dxaQT0Ta/CGfZna0nibHzxa55zgBmje/Ho3QDNlBCH +Ylv0h4Wyzto8NfLQAAAA1yb2JlcnRAYmJzZGV2AQID +-----END OPENSSH PRIVATE KEY----- +"; + assert!( + decode_secret_key(key, None).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP384 + }, + ); + } + + #[test] + fn test_decode_openssh_p521_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 521 -m rfc4716 -f $file + let key = "-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS +1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQA7a9awmFeDjzYiuUOwMfXkKTevfQI +iGlduu8BkjBOWXpffJpKsdTyJI/xI05l34OvqfCCkPUcfFWHK+LVRGahMBgBcGB9ZZOEEq +iKNIT6C9WcJTGDqcBSzQ2yTSOxPXfUmVTr4D76vbYu5bjd9aBKx8HdfMvPeo0WD0ds/LjX +LdJoDXcAAAEQ9fxlIfX8ZSEAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ +AAAIUEAO2vWsJhXg482IrlDsDH15Ck3r30CIhpXbrvAZIwTll6X3yaSrHU8iSP8SNOZd+D +r6nwgpD1HHxVhyvi1URmoTAYAXBgfWWThBKoijSE+gvVnCUxg6nAUs0Nsk0jsT131JlU6+ +A++r22LuW43fWgSsfB3XzLz3qNFg9HbPy41y3SaA13AAAAQgH4DaftY0e/KsN695VJ06wy +Ve0k2ddxoEsSE15H4lgNHM2iuYKzIqZJOReHRCTff6QGgMYPDqDfFfL1Hc1Ntql0pwAAAA +1yb2JlcnRAYmJzZGV2AQIDBAU= +-----END OPENSSH PRIVATE KEY----- +"; + assert!( + decode_secret_key(key, None).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP521 + }, + ); + } + + #[test] fn test_fingerprint() { let key = parse_public_key_base64( "AAAAC3NzaC1lZDI1NTE5AAAAILagOJFgwaMNhBWQINinKOXmqS4Gh5NgxgriXwdOoINJ", ) .unwrap(); assert_eq!( - key.fingerprint(), - "ldyiXa1JQakitNU5tErauu8DvWQ1dZ7aXu+rm7KQuog" + format!("{}", key.fingerprint(ssh_key::HashAlg::Sha256)), + "SHA256:ldyiXa1JQakitNU5tErauu8DvWQ1dZ7aXu+rm7KQuog" ); } #[test] - fn test_check_known_hosts() { + fn test_parse_p256_public_key() { env_logger::try_init().unwrap_or(()); - let dir = tempdir::TempDir::new("russh").unwrap(); - let path = dir.path().join("known_hosts"); - { - let mut f = File::create(&path).unwrap(); - f.write(b"[localhost]:13265 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ\n#pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\npijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); - } + let key = "AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBMxBTpMIGvo7CnordO7wP0QQRqpBwUjOLl4eMhfucfE1sjTYyK5wmTl1UqoSDS1PtRVTBdl+0+9pquFb46U7fwg="; - // Valid key, non-standard port. - let host = "localhost"; - let port = 13265; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + assert!( + parse_public_key_base64(key).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP256 + }, + ); + } - // Valid key, several hosts, port 22 - let host = "pijul.org"; - let port = 22; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + #[test] + fn test_parse_p384_public_key() { + env_logger::try_init().unwrap_or(()); + let key = "AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBVFgxJxpCaAALZG/S5BHT8/IUQ5mfuKaj7Av9g7Jw59fBEGHfPBz1wFtHGYw5bdLmfVZTIDfogDid5zqJeAKr1AcD06DKTXDzd2EpUjqeLfQ5b3erHuX758fgu/pSDGRA=="; - // Now with the key in a comment above, check that it's not recognized - let host = "pijul.org"; - let port = 22; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).is_err()); + assert!( + parse_public_key_base64(key).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP384 + } + ); + } + + #[test] + fn test_parse_p521_public_key() { + env_logger::try_init().unwrap_or(()); + let key = "AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBAAQepXEpOrzlX22r4E5zEHjhHWeZUe//zaevTanOWRBnnaCGWJFGCdjeAbNOuAmLtXc+HZdJTCZGREeSLSrpJa71QDCgZl0N7DkDUanCpHZJe/DCK6qwtHYbEMn28iLMlGCOrCIa060EyJHbp1xcJx4I1SKj/f/fm3DhhID/do6zyf8Cg=="; + + assert!( + parse_public_key_base64(key).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP521 + } + ); } #[test] - #[cfg(feature = "openssl")] fn test_srhb() { env_logger::try_init().unwrap_or(()); let key = "AAAAB3NzaC1yc2EAAAADAQABAAACAQC0Xtz3tSNgbUQAXem4d+d6hMx7S8Nwm/DOO2AWyWCru+n/+jQ7wz2b5+3oG2+7GbWZNGj8HCc6wJSA3jUsgv1N6PImIWclD14qvoqY3Dea1J0CJgXnnM1xKzBz9C6pDHGvdtySg+yzEO41Xt4u7HFn4Zx5SGuI2NBsF5mtMLZXSi33jCIWVIkrJVd7sZaY8jiqeVZBB/UvkLPWewGVuSXZHT84pNw4+S0Rh6P6zdNutK+JbeuO+5Bav4h9iw4t2sdRkEiWg/AdMoSKmo97Gigq2mKdW12ivnXxz3VfxrCgYJj9WwaUUWSfnAju5SiNly0cTEAN4dJ7yB0mfLKope1kRhPsNaOuUmMUqlu/hBDM/luOCzNjyVJ+0LLB7SV5vOiV7xkVd4KbEGKou8eeCR3yjFazUe/D1pjYPssPL8cJhTSuMc+/UC9zD8yeEZhB9V+vW4NMUR+lh5+XeOzenl65lWYd/nBZXLBbpUMf1AOfbz65xluwCxr2D2lj46iApSIpvE63i3LzFkbGl9GdUiuZJLMFJzOWdhGGc97cB5OVyf8umZLqMHjaImxHEHrnPh1MOVpv87HYJtSBEsN4/omINCMZrk++CRYAIRKRpPKFWV7NQHcvw3m7XLR3KaTYe+0/MINIZwGdou9fLUU3zSd521vDjA/weasH0CyDHq7sZw=="; @@ -567,7 +510,6 @@ QR+u0AypRPmzHnOPAAAAEXJvb3RAMTQwOTExNTQ5NDBkAQ== } #[test] - #[cfg(feature = "openssl")] fn test_nikao() { env_logger::try_init().unwrap_or(()); let key = "-----BEGIN RSA PRIVATE KEY----- @@ -601,81 +543,144 @@ QaChXiDsryJZwsRnruvMRX9nedtqHrgnIsJLTXjppIhGhq5Kg4RQfOU= decode_secret_key(key, None).unwrap(); } - #[cfg(feature = "openssl")] - pub const PKCS8_RSA: &str = "-----BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEAwBGetHjW+3bDQpVktdemnk7JXgu1NBWUM+ysifYLDBvJ9ttX -GNZSyQKA4v/dNr0FhAJ8I9BuOTjYCy1YfKylhl5D/DiSSXFPsQzERMmGgAlYvU2U -+FTxpBC11EZg69CPVMKKevfoUD+PZA5zB7Hc1dXFfwqFc5249SdbAwD39VTbrOUI -WECvWZs6/ucQxHHXP2O9qxWqhzb/ddOnqsDHUNoeceiNiCf2anNymovrIMjAqq1R -t2UP3f06/Zt7Jx5AxKqS4seFkaDlMAK8JkEDuMDOdKI36raHkKanfx8CnGMSNjFQ -QtvnpD8VSGkDTJN3Qs14vj2wvS477BQXkBKN1QIDAQABAoIBABb6xLMw9f+2ENyJ -hTggagXsxTjkS7TElCu2OFp1PpMfTAWl7oDBO7xi+UqvdCcVbHCD35hlWpqsC2Ui -8sBP46n040ts9UumK/Ox5FWaiuYMuDpF6vnfJ94KRcb0+KmeFVf9wpW9zWS0hhJh -jC+yfwpyfiOZ/ad8imGCaOguGHyYiiwbRf381T/1FlaOGSae88h+O8SKTG1Oahq4 -0HZ/KBQf9pij0mfVQhYBzsNu2JsHNx9+DwJkrXT7K9SHBpiBAKisTTCnQmS89GtE -6J2+bq96WgugiM7X6OPnmBmE/q1TgV18OhT+rlvvNi5/n8Z1ag5Xlg1Rtq/bxByP -CeIVHsECgYEA9dX+LQdv/Mg/VGIos2LbpJUhJDj0XWnTRq9Kk2tVzr+9aL5VikEb -09UPIEa2ToL6LjlkDOnyqIMd/WY1W0+9Zf1ttg43S/6Rvv1W8YQde0Nc7QTcuZ1K -9jSSP9hzsa3KZtx0fCtvVHm+ac9fP6u80tqumbiD2F0cnCZcSxOb4+UCgYEAyAKJ -70nNKegH4rTCStAqR7WGAsdPE3hBsC814jguplCpb4TwID+U78Xxu0DQF8WtVJ10 -SJuR0R2q4L9uYWpo0MxdawSK5s9Am27MtJL0mkFQX0QiM7hSZ3oqimsdUdXwxCGg -oktxCUUHDIPJNVd4Xjg0JTh4UZT6WK9hl1zLQzECgYEAiZRCFGc2KCzVLF9m0cXA -kGIZUxFAyMqBv+w3+zq1oegyk1z5uE7pyOpS9cg9HME2TAo4UPXYpLAEZ5z8vWZp -45sp/BoGnlQQsudK8gzzBtnTNp5i/MnnetQ/CNYVIVnWjSxRUHBqdMdRZhv0/Uga -e5KA5myZ9MtfSJA7VJTbyHUCgYBCcS13M1IXaMAt3JRqm+pftfqVs7YeJqXTrGs/ -AiDlGQigRk4quFR2rpAV/3rhWsawxDmb4So4iJ16Wb2GWP4G1sz1vyWRdSnmOJGC -LwtYrvfPHegqvEGLpHa7UsgDpol77hvZriwXwzmLO8A8mxkeW5dfAfpeR5o+mcxW -pvnTEQKBgQCKx6Ln0ku6jDyuDzA9xV2/PET5D75X61R2yhdxi8zurY/5Qon3OWzk -jn/nHT3AZghGngOnzyv9wPMKt9BTHyTB6DlB6bRVLDkmNqZh5Wi8U1/IjyNYI0t2 -xV/JrzLAwPoKk3bkqys3bUmgo6DxVC/6RmMwPQ0rmpw78kOgEej90g== ------END RSA PRIVATE KEY----- + #[test] + fn test_decode_pkcs8_rsa_secret_key() { + // Generated using: ssh-keygen -t rsa -b 1024 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDTwWfiCKHw/1F6 +pvm6hZpFSjCVSu4Pp0/M4xT9Cec1+2uj/6uEE9Vh/UhlerkxVbrW/YaqjnlAiemZ +0RGN+sq7b8LxsgvOAo7gdBv13TLkKxNFiRbSy8S257uA9/K7G4Uw+NW22zoLSKCp +pdJOFzaYMIT/UX9EOq9hIIn4bS4nXJ4V5+aHBtMddHHDQPEDHBHuifpP2L4Wopzu +WoQoVtN9cwHSLh0Bd7uT+X9useIJrFzcsxVXwD2WGfR59Ue3rxRu6JqC46Klf55R +5NQ8OQ+7NHXjW5HO076W1GXcnhGKT5CGjglTdk5XxQkNZsz72cHu7RDaADdWAWnE +hSyH7flrAgMBAAECggEAbFdpCjn2eTJ4grOJ1AflTYxO3SOQN8wXxTFuHKUDehgg +E7GNFK99HnyTnPA0bmx5guQGEZ+BpCarsXpJbAYj0dC1wimhZo7igS6G272H+zua +yZoBZmrBQ/++bJbvxxGmjM7TsZHq2bkYEpR3zGKOGUHB2kvdPJB2CNC4JrXdxl7q +djjsr5f/SreDmHqcNBe1LcyWLSsuKTfwTKhsE1qEe6QA2uOpUuFrsdPoeYrfgapu +sK6qnpxvOTJHCN/9jjetrP2fGl78FMBYfXzjAyKSKzLvzOwMAmcHxy50RgUvezx7 +A1RwMpB7VoV0MOpcAjlQ1T7YDH9avdPMzp0EZ24y+QKBgQD/MxDJjHu33w13MnIg +R4BrgXvrgL89Zde5tML2/U9C2LRvFjbBvgnYdqLsuqxDxGY/8XerrAkubi7Fx7QI +m2uvTOZF915UT/64T35zk8nAAFhzicCosVCnBEySvdwaaBKoj/ywemGrwoyprgFe +r8LGSo42uJi0zNf5IxmVzrDlRwKBgQDUa3P/+GxgpUYnmlt63/7sII6HDssdTHa9 +x5uPy8/2ackNR7FruEAJR1jz6akvKnvtbCBeRxLeOFwsseFta8rb2vks7a/3I8ph +gJlbw5Bttpc+QsNgC61TdSKVsfWWae+YT77cfGPM4RaLlxRnccW1/HZjP2AMiDYG +WCiluO+svQKBgQC3a/yk4FQL1EXZZmigysOCgY6Ptfm+J3TmBQYcf/R4F0mYjl7M +4coxyxNPEty92Gulieh5ey0eMhNsFB1SEmNTm/HmV+V0tApgbsJ0T8SyO41Xfar7 +lHZjlLN0xQFt+V9vyA3Wyh9pVGvFiUtywuE7pFqS+hrH2HNindfF1MlQAQKBgQDF +YxBIxKzY5duaA2qMdMcq3lnzEIEXua0BTxGz/n1CCizkZUFtyqnetWjoRrGK/Zxp +FDfDw6G50397nNPQXQEFaaZv5HLGYYC3N8vKJKD6AljqZxmsD03BprA7kEGYwtn8 +m+XMdt46TNMpZXt1YJiLMo1ETmjPXGdvX85tqLs2tQKBgQDCbwd+OBzSiic3IQlD +E/OHAXH6HNHmUL3VD5IiRh4At2VAIl8JsmafUvvbtr5dfT3PA8HB6sDG4iXQsBbR +oTSAo/DtIWt1SllGx6MvcPqL1hp1UWfoIGTnE3unHtgPId+DnjMbTcuZOuGl7evf +abw8VeY2goORjpBXsfydBETbgQ== +-----END PRIVATE KEY----- "; + assert!(decode_secret_key(key, None).unwrap().algorithm().is_rsa()); + test_decode_encode_symmetry(key); + } #[test] - #[cfg(feature = "openssl")] - fn test_loewenheim() -> Result<(), Error> { - env_logger::try_init().unwrap_or(()); - let key = "-----BEGIN RSA PRIVATE KEY----- -Proc-Type: 4,ENCRYPTED -DEK-Info: AES-128-CBC,80E4FCAD049EE007CCE1C65D52CDB87A - -ZKBKtex8+DA/d08TTPp4vY8RV+r+1nUC1La+r0dSiXsfunRNDPcYhHbyA/Fdr9kQ -+d1/E3cEb0k2nq7xYyMzy8hpNp/uHu7UfllGdaBusiPjHR+feg6AQfbM0FWpdGzo -9l/Vho5Ocw8abQq1Q9aPW5QQXBURC7HtCQXbpuYjUAQBeea1LzPCw6UIF80GUUkY -1AycXxVfx1AeURAKTZR4hsxC5pqI4yhAvVNXxP+tTTa9NE8lOP0yqVNurfIqyAnp -5ELMwNdHXZyUcT+EH5PsC69ocQgEZqLs0chvke62woMOjeSpsW5cIjGohW9lOD1f -nJkECVZ50kE0SDvcL4Y338tHwMt7wdwdj1dkAWSUjAJT4ShjqV/TzaLAiNAyRxLl -cm3mAccaFIIBZG/bPLGI0B5+mf9VExXGJrbGlvURhtE3nwmjLg1vT8lVfqbyL3a+ -0tFvmDYn71L97t/3hcD2tVnKLv9g8+/OCsUAk3+/0eS7D6GpmlOMRHdLLUHc4SOm -bIDT/dE6MjsCSm7n/JkTb8P+Ta1Hp94dUnX4pfjzZ+O8V1H8wv7QW5KsuJhJ8cn4 -eS3BEgNH1I4FCCjLsZdWve9ehV3/19WXh+BF4WXFq9b3plmfJgTiZslvjy4dgThm -OhEK44+fN1UhzguofxTR4Maz7lcehQxGAxp14hf1EnaAEt3LVjEPEShgK5dx1Ftu -LWFz9nR4vZcMsaiszElrevqMhPQHXY7cnWqBenkMfkdcQDoZjKvV86K98kBIDMu+ -kf855vqRF8b2n/6HPdm3eqFh/F410nSB0bBSglUfyOZH1nS+cs79RQZEF9fNUmpH -EPQtQ/PALohicj9Vh7rRaMKpsORdC8/Ahh20s01xL6siZ334ka3BLYT94UG796/C -4K1S2kPdUP8POJ2HhaK2l6qaG8tcEX7HbwwZeKiEHVNvWuIGQO9TiDONLycp9x4y -kNM3sv2pI7vEhs7d2NapWgNha1RcTSv0CQ6Th/qhGo73LBpVmKwombVImHAyMGAE -aVF32OycVd9c9tDgW5KdhWedbeaxD6qkSs0no71083kYIS7c6iC1R3ZeufEkMhmx -dwrciWTJ+ZAk6rS975onKz6mo/4PytcCY7Df/6xUxHF3iJCnuK8hNpLdJcdOiqEK -zj/d5YGyw3J2r+NrlV1gs3FyvR3eMCWWH2gpIQISBpnEANY40PxA/ogH+nCUvI/O -n8m437ZeLTg6lnPqsE4nlk2hUEwRdy/SVaQURbn7YlcYIt0e81r5sBXb4MXkLrf0 -XRWmpSggdcaaMuXi7nVSdkgCMjGP7epS7HsfP46OrTtJLHn5LxvdOEaW53nPOVQg -/PlVfDbwWl8adE3i3PDQOw9jhYXnYS3sv4R8M8y2GYEXbINrTJyUGrlNggKFS6oh -Hjgt0gsM2N/D8vBrQwnRtyymRnFd4dXFEYKAyt+vk0sa36eLfl0z6bWzIchkJbdu -raMODVc+NiJE0Qe6bwAi4HSpJ0qw2lKwVHYB8cdnNVv13acApod326/9itdbb3lt -KJaj7gc0n6gmKY6r0/Ddufy1JZ6eihBCSJ64RARBXeg2rZpyT+xxhMEZLK5meOeR ------END RSA PRIVATE KEY----- + fn test_decode_pkcs8_p256_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 256 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgE0C7/pyJDcZTAgWo +ydj6EE8QkZ91jtGoGmdYAVd7LaqhRANCAATWkGOof7R/PAUuOr2+ZPUgB8rGVvgr +qa92U3p4fkJToKXku5eq/32OBj23YMtz76jO3yfMbtG3l1JWLowPA8tV +-----END PRIVATE KEY----- "; - let key = decode_secret_key(key, Some("passphrase")).unwrap(); - let public = key.clone_public_key()?; - let buf = b"blabla"; - let sig = key.sign_detached(buf).unwrap(); - assert!(public.verify_detached(buf, sig.as_ref())); + assert!( + decode_secret_key(key, None).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP256 + }, + ); + test_decode_encode_symmetry(key); + } + + #[test] + fn test_decode_pkcs8_p384_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 384 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDCaqAL30kg+T5BUOYG9 +MrzeDXiUwy9LM8qJGNXiMYou0pVjFZPZT3jAsrUQo47PLQ6hZANiAARuEHbXJBYK +9uyJj4PjT56OHjT2GqMa6i+FTG9vdLtu4OLUkXku+kOuFNjKvEI1JYBrJTpw9kSZ +CI3WfCsQvVjoC7m8qRyxuvR3Rv8gGXR1coQciIoCurLnn9zOFvXCS2Y= +-----END PRIVATE KEY----- +"; + assert!( + decode_secret_key(key, None).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP384 + }, + ); + test_decode_encode_symmetry(key); + } + + #[test] + fn test_decode_pkcs8_p521_secret_key() { + // Generated using: ssh-keygen -t ecdsa -b 521 -m pkcs8 -f $file + let key = "-----BEGIN PRIVATE KEY----- +MIHuAgEAMBAGByqGSM49AgEGBSuBBAAjBIHWMIHTAgEBBEIB1As9UBUsCiMK7Rzs +EoMgqDM/TK7y7+HgCWzw5UujXvSXCzYCeBgfJszn7dVoJE9G/1ejmpnVTnypdKEu +iIvd4LyhgYkDgYYABAADBCrg7hkomJbCsPMuMcq68ulmo/6Tv8BDS13F8T14v5RN +/0iT/+nwp6CnbBFewMI2TOh/UZNyPpQ8wOFNn9zBmAFCMzkQibnSWK0hrRstY5LT +iaOYDwInbFDsHu8j3TGs29KxyVXMexeV6ROQyXzjVC/quT1R5cOQ7EadE4HvaWhT +Ow== +-----END PRIVATE KEY----- +"; + assert!( + decode_secret_key(key, None).unwrap().algorithm() + == ssh_key::Algorithm::Ecdsa { + curve: ssh_key::EcdsaCurve::NistP521 + }, + ); + test_decode_encode_symmetry(key); + } + + #[test] + #[cfg(feature = "legacy-ed25519-pkcs8-parser")] + fn test_decode_pkcs8_ed25519_generated_by_russh_0_43() -> Result<(), crate::keys::Error> { + // Generated by russh 0.43 + let key = "-----BEGIN PRIVATE KEY----- +MHMCAQEwBQYDK2VwBEIEQBHw4cXPpGgA+KdvPF5gxrzML+oa3yQk0JzIbWvmqM5H30RyBF8GrOWz +p77UAd3O4PgYzzFcUc79g8yKtbKhzJGhIwMhAN9EcgRfBqzls6e+1AHdzuD4GM8xXFHO/YPMirWy +ocyR + +-----END PRIVATE KEY----- +"; + + assert!(decode_secret_key(key, None)?.algorithm() == ssh_key::Algorithm::Ed25519,); + + let k = decode_secret_key(key, None)?; + let inner = k.key_data().ed25519().unwrap(); + + assert_eq!( + &inner.private.to_bytes(), + &[ + 17, 240, 225, 197, 207, 164, 104, 0, 248, 167, 111, 60, 94, 96, 198, 188, 204, 47, + 234, 26, 223, 36, 36, 208, 156, 200, 109, 107, 230, 168, 206, 71 + ] + ); + Ok(()) } + fn test_decode_encode_symmetry(key: &str) { + let original_key_bytes = data_encoding::BASE64_MIME + .decode( + key.lines() + .filter(|line| !line.starts_with("-----")) + .collect::>() + .join("") + .as_bytes(), + ) + .unwrap(); + let decoded_key = decode_secret_key(key, None).unwrap(); + let encoded_key_bytes = pkcs8::encode_pkcs8(&decoded_key).unwrap(); + assert_eq!(original_key_bytes, encoded_key_bytes); + } + #[test] - #[cfg(feature = "openssl")] fn test_o01eg() { env_logger::try_init().unwrap_or(()); @@ -712,15 +717,43 @@ br8gXU8KyiY9sZVbmplRPF+ar462zcI2kt0a18mr0vbrdqp2eMjb37QDbVBJ+rPE "; decode_secret_key(key, Some("12345")).unwrap(); } + + pub const PKCS8_RSA: &str = "-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAwBGetHjW+3bDQpVktdemnk7JXgu1NBWUM+ysifYLDBvJ9ttX +GNZSyQKA4v/dNr0FhAJ8I9BuOTjYCy1YfKylhl5D/DiSSXFPsQzERMmGgAlYvU2U ++FTxpBC11EZg69CPVMKKevfoUD+PZA5zB7Hc1dXFfwqFc5249SdbAwD39VTbrOUI +WECvWZs6/ucQxHHXP2O9qxWqhzb/ddOnqsDHUNoeceiNiCf2anNymovrIMjAqq1R +t2UP3f06/Zt7Jx5AxKqS4seFkaDlMAK8JkEDuMDOdKI36raHkKanfx8CnGMSNjFQ +QtvnpD8VSGkDTJN3Qs14vj2wvS477BQXkBKN1QIDAQABAoIBABb6xLMw9f+2ENyJ +hTggagXsxTjkS7TElCu2OFp1PpMfTAWl7oDBO7xi+UqvdCcVbHCD35hlWpqsC2Ui +8sBP46n040ts9UumK/Ox5FWaiuYMuDpF6vnfJ94KRcb0+KmeFVf9wpW9zWS0hhJh +jC+yfwpyfiOZ/ad8imGCaOguGHyYiiwbRf381T/1FlaOGSae88h+O8SKTG1Oahq4 +0HZ/KBQf9pij0mfVQhYBzsNu2JsHNx9+DwJkrXT7K9SHBpiBAKisTTCnQmS89GtE +6J2+bq96WgugiM7X6OPnmBmE/q1TgV18OhT+rlvvNi5/n8Z1ag5Xlg1Rtq/bxByP +CeIVHsECgYEA9dX+LQdv/Mg/VGIos2LbpJUhJDj0XWnTRq9Kk2tVzr+9aL5VikEb +09UPIEa2ToL6LjlkDOnyqIMd/WY1W0+9Zf1ttg43S/6Rvv1W8YQde0Nc7QTcuZ1K +9jSSP9hzsa3KZtx0fCtvVHm+ac9fP6u80tqumbiD2F0cnCZcSxOb4+UCgYEAyAKJ +70nNKegH4rTCStAqR7WGAsdPE3hBsC814jguplCpb4TwID+U78Xxu0DQF8WtVJ10 +SJuR0R2q4L9uYWpo0MxdawSK5s9Am27MtJL0mkFQX0QiM7hSZ3oqimsdUdXwxCGg +oktxCUUHDIPJNVd4Xjg0JTh4UZT6WK9hl1zLQzECgYEAiZRCFGc2KCzVLF9m0cXA +kGIZUxFAyMqBv+w3+zq1oegyk1z5uE7pyOpS9cg9HME2TAo4UPXYpLAEZ5z8vWZp +45sp/BoGnlQQsudK8gzzBtnTNp5i/MnnetQ/CNYVIVnWjSxRUHBqdMdRZhv0/Uga +e5KA5myZ9MtfSJA7VJTbyHUCgYBCcS13M1IXaMAt3JRqm+pftfqVs7YeJqXTrGs/ +AiDlGQigRk4quFR2rpAV/3rhWsawxDmb4So4iJ16Wb2GWP4G1sz1vyWRdSnmOJGC +LwtYrvfPHegqvEGLpHa7UsgDpol77hvZriwXwzmLO8A8mxkeW5dfAfpeR5o+mcxW +pvnTEQKBgQCKx6Ln0ku6jDyuDzA9xV2/PET5D75X61R2yhdxi8zurY/5Qon3OWzk +jn/nHT3AZghGngOnzyv9wPMKt9BTHyTB6DlB6bRVLDkmNqZh5Wi8U1/IjyNYI0t2 +xV/JrzLAwPoKk3bkqys3bUmgo6DxVC/6RmMwPQ0rmpw78kOgEej90g== +-----END RSA PRIVATE KEY----- +"; + #[test] - #[cfg(feature = "openssl")] fn test_pkcs8() { env_logger::try_init().unwrap_or(()); println!("test"); decode_secret_key(PKCS8_RSA, Some("blabla")).unwrap(); } - #[cfg(feature = "openssl")] const PKCS8_ENCRYPTED: &str = "-----BEGIN ENCRYPTED PRIVATE KEY----- MIIFLTBXBgkqhkiG9w0BBQ0wSjApBgkqhkiG9w0BBQwwHAQITo1O0b8YrS0CAggA MAwGCCqGSIb3DQIJBQAwHQYJYIZIAWUDBAEqBBBtLH4T1KOfo1GGr7salhR8BIIE @@ -753,10 +786,8 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux -----END ENCRYPTED PRIVATE KEY-----"; #[test] - #[cfg(feature = "openssl")] fn test_gpg() { env_logger::try_init().unwrap_or(()); - let algo = [115, 115, 104, 45, 114, 115, 97]; let key = [ 0, 0, 0, 7, 115, 115, 104, 45, 114, 115, 97, 0, 0, 0, 3, 1, 0, 1, 0, 0, 1, 129, 0, 163, 72, 59, 242, 4, 248, 139, 217, 57, 126, 18, 195, 170, 3, 94, 154, 9, 150, 89, 171, 236, @@ -781,12 +812,10 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux 117, 254, 51, 45, 93, 184, 80, 225, 158, 29, 76, 38, 69, 72, 71, 76, 50, 191, 210, 95, 152, 175, 26, 207, 91, 7, ]; - debug!("algo = {:?}", std::str::from_utf8(&algo)); - key::PublicKey::parse(&algo, &key).unwrap(); + ssh_key::PublicKey::decode(&key).unwrap(); } #[test] - #[cfg(feature = "openssl")] fn test_pkcs8_encrypted() { env_logger::try_init().unwrap_or(()); println!("test"); @@ -794,85 +823,90 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux } #[cfg(unix)] - fn test_client_agent(key: key::KeyPair) { + async fn test_client_agent(key: PrivateKey) -> Result<(), Box> { env_logger::try_init().unwrap_or(()); - use std::process::{Command, Stdio}; - let dir = tempdir::TempDir::new("russh").unwrap(); + use std::process::Stdio; + + let dir = tempfile::tempdir()?; let agent_path = dir.path().join("agent"); - let mut agent = Command::new("ssh-agent") + let mut agent = tokio::process::Command::new("ssh-agent") .arg("-a") .arg(&agent_path) .arg("-D") .stdout(Stdio::null()) .stderr(Stdio::null()) - .spawn() - .expect("failed to execute process"); - std::thread::sleep(std::time::Duration::from_millis(10)); - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async move { - let public = key.clone_public_key()?; - let stream = tokio::net::UnixStream::connect(&agent_path).await?; - let mut client = agent::client::AgentClient::connect(stream); - client.add_identity(&key, &[]).await?; - client.request_identities().await?; - let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); - let len = buf.len(); - let (_, buf) = client.sign_request(&public, buf).await; - let buf = buf?; - let (a, b) = buf.split_at(len); - match key { - key::KeyPair::Ed25519 { .. } => { - let sig = &b[b.len() - 64..]; - assert!(public.verify_detached(a, sig)); - } - #[cfg(feature = "openssl")] - _ => {} + .spawn()?; + + // Wait for the socket to be created + while agent_path.canonicalize().is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + let public = key.public_key(); + let stream = tokio::net::UnixStream::connect(&agent_path).await?; + let mut client = agent::client::AgentClient::connect(stream); + client.add_identity(&key, &[]).await?; + client.request_identities().await?; + let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); + let len = buf.len(); + let buf = client.sign_request(public, None, buf).await.unwrap(); + let (a, b) = buf.split_at(len); + + match key.public_key().key_data() { + ssh_key::public::KeyData::Ed25519 { .. } => { + let sig = &b[b.len() - 64..]; + let sig = ssh_key::Signature::new(key.algorithm(), sig)?; + use signature::Verifier; + assert!(Verifier::verify(public, a, &sig).is_ok()); } - Ok::<(), Error>(()) - }) - .unwrap(); - agent.kill().unwrap(); - agent.wait().unwrap(); + ssh_key::public::KeyData::Ecdsa { .. } => {} + _ => {} + } + + agent.kill().await?; + agent.wait().await?; + + Ok(()) } - #[test] + #[tokio::test] #[cfg(unix)] - fn test_client_agent_ed25519() { + async fn test_client_agent_ed25519() { let key = decode_secret_key(ED25519_KEY, Some("blabla")).unwrap(); - test_client_agent(key) + test_client_agent(key).await.expect("ssh-agent test failed") } - #[test] - #[cfg(feature = "openssl")] - fn test_client_agent_rsa() { + #[tokio::test] + #[cfg(unix)] + async fn test_client_agent_rsa() { let key = decode_secret_key(PKCS8_ENCRYPTED, Some("blabla")).unwrap(); - test_client_agent(key) + test_client_agent(key).await.expect("ssh-agent test failed") } - #[test] - #[cfg(feature = "openssl")] - fn test_client_agent_openssh_rsa() { + #[tokio::test] + #[cfg(unix)] + async fn test_client_agent_openssh_rsa() { let key = decode_secret_key(RSA_KEY, None).unwrap(); - test_client_agent(key) + test_client_agent(key).await.expect("ssh-agent test failed") } #[test] #[cfg(unix)] - #[cfg(feature = "openssl")] fn test_agent() { env_logger::try_init().unwrap_or(()); - let dir = tempdir::TempDir::new("russh").unwrap(); + let dir = tempfile::tempdir().unwrap(); let agent_path = dir.path().join("agent"); let core = tokio::runtime::Runtime::new().unwrap(); use agent; + use signature::Verifier; #[derive(Clone)] struct X {} impl agent::server::Agent for X { fn confirm( self, - _: std::sync::Arc, + _: std::sync::Arc, ) -> Box + Send + Unpin> { Box::new(futures::future::ready((self, true))) } @@ -891,28 +925,24 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux }); let key = decode_secret_key(PKCS8_ENCRYPTED, Some("blabla")).unwrap(); core.block_on(async move { - let public = key.clone_public_key()?; - let stream = tokio::net::UnixStream::connect(&agent_path).await?; + let public = key.public_key(); + let stream = tokio::net::UnixStream::connect(&agent_path).await.unwrap(); let mut client = agent::client::AgentClient::connect(stream); client .add_identity(&key, &[agent::Constraint::KeyLifetime { seconds: 60 }]) - .await?; - client.request_identities().await?; + .await + .unwrap(); + client.request_identities().await.unwrap(); let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); let len = buf.len(); - let (_, buf) = client.sign_request(&public, buf).await; - let buf = buf?; + let buf = client.sign_request(public, None, buf).await.unwrap(); let (a, b) = buf.split_at(len); - match key { - key::KeyPair::Ed25519 { .. } => { - let sig = &b[b.len() - 64..]; - assert!(public.verify_detached(a, sig)); - } - _ => {} + if let ssh_key::public::KeyData::Ed25519 { .. } = public.key_data() { + let sig = &b[b.len() - 64..]; + let sig = ssh_key::Signature::new(key.algorithm(), sig).unwrap(); + assert!(Verifier::verify(public, a, &sig).is_ok()); } - Ok::<(), Error>(()) }) - .unwrap() } #[cfg(unix)] diff --git a/russh/src/lib.rs b/russh/src/lib.rs index cf4cf8a9..4b8fa76e 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -30,11 +30,6 @@ //! * [Writing SSH clients - the `russh::client` module](client) //! * [Writing SSH servers - the `russh::server` module](server) //! -//! # Important crate features -//! -//! * RSA key support is gated behind the `openssl` feature (disabled by default). -//! * Enabling that and disabling the `rs-crypto` feature (enabled by default) will leave you with a very basic, but pure-OpenSSL RSA+AES cipherset. -//! //! # Using non-socket IO / writing tunnels //! //! The easy way to implement SSH tunnels, like `ProxyCommand` for @@ -50,7 +45,7 @@ //! relatively simple: clients and servers open *channels*, which are //! just integers used to handle multiple requests in parallel in a //! single connection. Once a client has obtained a `ChannelId` by -//! calling one the many `channel_open_…` methods of +//! calling one of the many `channel_open_…` methods of //! `client::Connection`, the client may send exec requests and data //! to the server. //! @@ -65,10 +60,7 @@ //! # Design principles //! //! The main goal of this library is conciseness, and reduced size and -//! readability of the library's code. Moreover, this library is split -//! between Russh, which implements the main logic of SSH clients -//! and servers, and Russh-keys, which implements calls to -//! cryptographic primitives. +//! readability of the library's code. //! //! One non-goal is to implement all possible cryptographic algorithms //! published since the initial release of SSH. Technical debt is @@ -94,23 +86,34 @@ //! messages sent through a `server::Handle` are processed when there //! is no incoming packet to read. +use std::convert::TryFrom; use std::fmt::{Debug, Display, Formatter}; +use std::future::{Future, Pending}; -use thiserror::Error; +use futures::future::Either as EitherFuture; +use log::{debug, warn}; use parsing::ChannelOpenConfirmation; pub use russh_cryptovec::CryptoVec; +use ssh_encoding::{Decode, Encode}; +use thiserror::Error; + +#[cfg(test)] +mod tests; mod auth; +mod cert; /// Cipher names pub mod cipher; +/// Compression algorithm names +pub mod compression; /// Key exchange algorithm names pub mod kex; /// MAC algorithm names pub mod mac; -mod compression; -mod key; +pub mod keys; + mod msg; mod negotiation; mod ssh_read; @@ -123,6 +126,8 @@ mod pty; pub use pty::Pty; pub use sshbuffer::SshId; +mod helpers; + macro_rules! push_packet { ( $buffer:expr, $x:expr ) => {{ use byteorder::{BigEndian, ByteOrder}; @@ -139,20 +144,27 @@ macro_rules! push_packet { } mod channels; -pub use channels::{Channel, ChannelMsg}; - -mod channel_stream; -pub use channel_stream::ChannelStream; +pub use channels::{Channel, ChannelMsg, ChannelReadHalf, ChannelStream, ChannelWriteHalf}; mod parsing; mod session; /// Server side of this library. +#[cfg(not(target_arch = "wasm32"))] pub mod server; /// Client side of this library. pub mod client; +#[derive(Debug)] +pub enum AlgorithmKind { + Kex, + Key, + Cipher, + Compression, + Mac, +} + #[derive(Debug, Error)] pub enum Error { /// The key file could not be parsed. @@ -167,25 +179,13 @@ pub enum Error { #[error("Unknown algorithm")] UnknownAlgo, - /// No common key exchange algorithm. - #[error("No common key exchange algorithm")] - NoCommonKexAlgo, - - /// No common signature algorithm. - #[error("No common key algorithm")] - NoCommonKeyAlgo, - - /// No common cipher. - #[error("No common key cipher")] - NoCommonCipher, - - /// No common compression algorithm. - #[error("No common compression algorithm")] - NoCommonCompression, - - /// No common MAC algorithm. - #[error("No common MAC algorithm")] - NoCommonMac, + /// No common algorithm found during key exchange. + #[error("No common {kind:?} algorithm - ours: {ours:?}, theirs: {theirs:?}")] + NoCommonAlgo { + kind: AlgorithmKind, + ours: Vec, + theirs: Vec, + }, /// Invalid SSH version string. #[error("invalid SSH version string")] @@ -219,6 +219,10 @@ pub enum Error { #[error("Wrong server signature")] WrongServerSig, + /// Excessive packet size. + #[error("Bad packet size: {0}")] + PacketSize(usize), + /// Message received/sent on unopened channel. #[error("Channel not open")] WrongChannel, @@ -248,6 +252,14 @@ pub enum Error { #[error("Connection timeout")] ConnectionTimeout, + /// Keepalive timeout. + #[error("Keepalive timeout")] + KeepaliveTimeout, + + /// Inactivity timeout. + #[error("Inactivity timeout")] + InactivityTimeout, + /// Missing authentication method. #[error("No authentication method")] NoAuthMethod, @@ -261,8 +273,11 @@ pub enum Error { #[error("Failed to decrypt a packet")] DecryptionError, + #[error("The request was rejected by the other party")] + RequestDenied, + #[error(transparent)] - Keys(#[from] russh_keys::Error), + Keys(#[from] crate::keys::Error), #[error(transparent)] IO(#[from] std::io::Error), @@ -271,20 +286,52 @@ pub enum Error { Utf8(#[from] std::str::Utf8Error), #[error(transparent)] + #[cfg(feature = "flate2")] Compress(#[from] flate2::CompressError), #[error(transparent)] + #[cfg(feature = "flate2")] Decompress(#[from] flate2::DecompressError), #[error(transparent)] - Join(#[from] tokio::task::JoinError), - - #[error(transparent)] - #[cfg(feature = "openssl")] - Openssl(#[from] openssl::error::ErrorStack), + Join(#[from] russh_util::runtime::JoinError), #[error(transparent)] Elapsed(#[from] tokio::time::error::Elapsed), + + #[error("Violation detected during strict key exchange, message {message_type} at seq no {sequence_number}")] + StrictKeyExchangeViolation { + message_type: u8, + sequence_number: usize, + }, + + #[error("Signature: {0}")] + Signature(#[from] signature::Error), + + #[error("SshKey: {0}")] + SshKey(#[from] ssh_key::Error), + + #[error("SshEncoding: {0}")] + SshEncoding(#[from] ssh_encoding::Error), + + #[error("Invalid config: {0}")] + InvalidConfig(String), + + /// This error occurs when the channel is closed and there are no remaining messages in the channel buffer. + /// This is common in SSH-Agent, for example when the Agent client directly rejects an authorization request. + #[error("Unable to receive more messages from the channel")] + RecvError, +} + +pub(crate) fn strict_kex_violation(message_type: u8, sequence_number: usize) -> crate::Error { + warn!( + "strict kex violated at sequence no. {:?}, message type: {:?}", + sequence_number, message_type + ); + crate::Error::StrictKeyExchangeViolation { + message_type, + sequence_number, + } } #[derive(Debug, Error)] @@ -325,10 +372,11 @@ impl Default for Limits { } } -pub use auth::{AgentAuthError, MethodSet, Signer}; +pub use auth::{AgentAuthError, MethodKind, MethodSet, Signer}; /// A reason for disconnection. #[allow(missing_docs)] // This should be relatively self-explanatory. +#[allow(clippy::manual_non_exhaustive)] #[derive(Debug)] pub enum Disconnect { HostNotAllowedToConnect = 1, @@ -349,6 +397,31 @@ pub enum Disconnect { IllegalUserName = 15, } +impl TryFrom for Disconnect { + type Error = crate::Error; + + fn try_from(value: u32) -> Result { + Ok(match value { + 1 => Self::HostNotAllowedToConnect, + 2 => Self::ProtocolError, + 3 => Self::KeyExchangeFailed, + 4 => Self::Reserved, + 5 => Self::MACError, + 6 => Self::CompressionError, + 7 => Self::ServiceNotAvailable, + 8 => Self::ProtocolVersionNotSupported, + 9 => Self::HostKeyNotVerifiable, + 10 => Self::ConnectionLost, + 11 => Self::ByApplication, + 12 => Self::TooManyConnections, + 13 => Self::AuthCancelledByUser, + 14 => Self::NoMoreAuthMethodsAvailable, + 15 => Self::IllegalUserName, + _ => return Err(crate::Error::Inconsistent), + }) + } +} + /// The type of signals that can be sent to a remote process. If you /// plan to use custom signals, read [the /// RFC](https://tools.ietf.org/html/rfc4254#section-6.10) to @@ -390,21 +463,21 @@ impl Sig { Sig::Custom(ref c) => c, } } - fn from_name(name: &[u8]) -> Result { + fn from_name(name: &str) -> Sig { match name { - b"ABRT" => Ok(Sig::ABRT), - b"ALRM" => Ok(Sig::ALRM), - b"FPE" => Ok(Sig::FPE), - b"HUP" => Ok(Sig::HUP), - b"ILL" => Ok(Sig::ILL), - b"INT" => Ok(Sig::INT), - b"KILL" => Ok(Sig::KILL), - b"PIPE" => Ok(Sig::PIPE), - b"QUIT" => Ok(Sig::QUIT), - b"SEGV" => Ok(Sig::SEGV), - b"TERM" => Ok(Sig::TERM), - b"USR1" => Ok(Sig::USR1), - x => Ok(Sig::Custom(std::str::from_utf8(x)?.to_string())), + "ABRT" => Sig::ABRT, + "ALRM" => Sig::ALRM, + "FPE" => Sig::FPE, + "HUP" => Sig::HUP, + "ILL" => Sig::ILL, + "INT" => Sig::INT, + "KILL" => Sig::KILL, + "PIPE" => Sig::PIPE, + "QUIT" => Sig::QUIT, + "SEGV" => Sig::SEGV, + "TERM" => Sig::TERM, + "USR1" => Sig::USR1, + x => Sig::Custom(x.to_string()), } } } @@ -432,10 +505,34 @@ impl ChannelOpenFailure { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] /// The identifier of a channel. pub struct ChannelId(u32); +impl Decode for ChannelId { + type Error = ssh_encoding::Error; + + fn decode(reader: &mut impl ssh_encoding::Reader) -> Result { + Ok(Self(u32::decode(reader)?)) + } +} + +impl Encode for ChannelId { + fn encoded_len(&self) -> Result { + self.0.encoded_len() + } + + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error> { + self.0.encode(writer) + } +} + +impl From for u32 { + fn from(c: ChannelId) -> u32 { + c.0 + } +} + impl Display for ChannelId { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) @@ -453,8 +550,12 @@ pub(crate) struct ChannelParams { sender_maximum_packet_size: u32, /// Has the other side confirmed the channel? pub confirmed: bool, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] wants_reply: bool, + /// (buffer, extended stream #, data offset in buffer) pending_data: std::collections::VecDeque<(CryptoVec, Option, usize)>, + pending_eof: bool, + pending_close: bool, } impl ChannelParams { @@ -466,479 +567,13 @@ impl ChannelParams { } } -#[cfg(test)] -mod test_compress { - use std::collections::HashMap; - use std::sync::{Arc, Mutex}; - - use async_trait::async_trait; - use log::debug; - - use super::server::{Server as _, Session}; - use super::*; - use crate::server::Msg; - - #[tokio::test] - async fn compress_local_test() { - let _ = env_logger::try_init(); - - let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); - let mut config = server::Config::default(); - config.preferred = Preferred::COMPRESSED; - config.inactivity_timeout = None; // Some(std::time::Duration::from_secs(3)); - config.auth_rejection_time = std::time::Duration::from_secs(3); - config - .keys - .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); - let config = Arc::new(config); - let mut sh = Server { - clients: Arc::new(Mutex::new(HashMap::new())), - id: 0, - }; - - let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = socket.local_addr().unwrap(); - - tokio::spawn(async move { - let (socket, _) = socket.accept().await.unwrap(); - let server = sh.new_client(socket.peer_addr().ok()); - server::run_stream(config, socket, server).await.unwrap(); - }); - - let mut config = client::Config::default(); - config.preferred = Preferred::COMPRESSED; - let config = Arc::new(config); - - dbg!(&addr); - let mut session = client::connect(config, addr, Client {}).await.unwrap(); - let authenticated = session - .authenticate_publickey( - std::env::var("USER").unwrap_or("user".to_owned()), - Arc::new(client_key), - ) - .await - .unwrap(); - assert!(authenticated); - let mut channel = session.channel_open_session().await.unwrap(); - - let data = &b"Hello, world!"[..]; - channel.data(data).await.unwrap(); - let msg = channel.wait().await.unwrap(); - match msg { - ChannelMsg::Data { data: msg_data } => { - assert_eq!(*data, *msg_data) - } - msg => panic!("Unexpected message {:?}", msg), - } - } - - #[derive(Clone)] - struct Server { - clients: Arc>>, - id: usize, - } - - impl server::Server for Server { - type Handler = Self; - fn new_client(&mut self, _: Option) -> Self { - let s = self.clone(); - self.id += 1; - s - } - } - - #[async_trait] - impl server::Handler for Server { - type Error = super::Error; - - async fn channel_open_session( - self, - channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - { - let mut clients = self.clients.lock().unwrap(); - clients.insert((self.id, channel.id()), session.handle()); - } - Ok((self, true, session)) - } - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - debug!("auth_publickey"); - Ok((self, server::Auth::Accept)) - } - async fn data( - self, - channel: ChannelId, - data: &[u8], - mut session: Session, - ) -> Result<(Self, Session), Self::Error> { - debug!("server data = {:?}", std::str::from_utf8(data)); - session.data(channel, CryptoVec::from_slice(data)); - Ok((self, session)) - } - } - - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = super::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - // println!("check_server_key: {:?}", server_public_key); - Ok((self, true)) - } - } -} - -#[cfg(test)] -use futures::Future; - -#[cfg(test)] -async fn test_session( - client_handler: CH, - server_handler: SH, - run_client: RC, - run_server: RS, -) where - RC: FnOnce(crate::client::Handle) -> F1 + Send + Sync + 'static, - RS: FnOnce(crate::server::Handle) -> F2 + Send + Sync + 'static, - F1: Future> + Send + Sync + 'static, - F2: Future + Send + Sync + 'static, - CH: crate::client::Handler + Send + Sync + 'static, - SH: crate::server::Handler + Send + Sync + 'static, -{ - use std::sync::Arc; - - use crate::*; - - let _ = env_logger::try_init(); - - let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); - let mut config = server::Config::default(); - config.inactivity_timeout = None; - config.auth_rejection_time = std::time::Duration::from_secs(3); - config - .keys - .push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); - let config = Arc::new(config); - let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = socket.local_addr().unwrap(); - - #[derive(Clone)] - struct Server {} - - let server_join = tokio::spawn(async move { - let (socket, _) = socket.accept().await.unwrap(); - - server::run_stream(config, socket, server_handler) - .await - .map_err(|_| ()) - .unwrap() - }); - - let client_join = tokio::spawn(async move { - let config = Arc::new(client::Config::default()); - let mut session = client::connect(config, addr, client_handler) - .await - .map_err(|_| ()) - .unwrap(); - let authenticated = session - .authenticate_publickey( - std::env::var("USER").unwrap_or("user".to_owned()), - Arc::new(client_key), - ) - .await - .unwrap(); - assert!(authenticated); - session - }); - - let (server_session, client_session) = tokio::join!(server_join, client_join); - let client_handle = tokio::spawn(run_client(client_session.unwrap())); - let server_handle = tokio::spawn(run_server(server_session.unwrap().handle())); - - let (server_session, client_session) = tokio::join!(server_handle, client_handle); - drop(client_session); - drop(server_session); -} - -#[cfg(test)] -mod test_channels { - use async_trait::async_trait; - use russh_cryptovec::CryptoVec; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - use crate::server::Session; - use crate::{client, server, test_session, Channel, ChannelId, ChannelMsg}; - - #[tokio::test] - async fn test_server_channels() { - #[derive(Debug)] - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = crate::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } - - async fn data( - self, - channel: ChannelId, - data: &[u8], - mut session: client::Session, - ) -> Result<(Self, client::Session), Self::Error> { - assert_eq!(data, &b"hello world!"[..]); - session.data(channel, CryptoVec::from_slice(&b"hey there!"[..])); - Ok((self, session)) - } - } - - struct ServerHandle { - did_auth: Option>, - } - - impl ServerHandle { - fn get_auth_waiter(&mut self) -> tokio::sync::oneshot::Receiver<()> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.did_auth = Some(tx); - rx - } - } - - #[async_trait] - impl server::Handler for ServerHandle { - type Error = crate::Error; - - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) - } - async fn auth_succeeded( - mut self, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - if let Some(a) = self.did_auth.take() { - a.send(()).unwrap(); - } - Ok((self, session)) - } - } - - let mut sh = ServerHandle { did_auth: None }; - let a = sh.get_auth_waiter(); - test_session( - Client {}, - sh, - |c| async move { c }, - |s| async move { - a.await.unwrap(); - let mut ch = s.channel_open_session().await.unwrap(); - ch.data(&b"hello world!"[..]).await.unwrap(); - - let msg = ch.wait().await.unwrap(); - if let ChannelMsg::Data { data } = msg { - assert_eq!(data.as_ref(), &b"hey there!"[..]); - } else { - panic!("Unexpected message {:?}", msg); - } - s - }, - ) - .await; - } - - #[tokio::test] - async fn test_channel_streams() { - #[derive(Debug)] - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = crate::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } - } - - struct ServerHandle { - channel: Option>>, - } - - impl ServerHandle { - fn get_channel_waiter( - &mut self, - ) -> tokio::sync::oneshot::Receiver> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - self.channel = Some(tx); - rx - } - } - - #[async_trait] - impl server::Handler for ServerHandle { - type Error = crate::Error; - - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) - } - - async fn channel_open_session( - mut self, - channel: Channel, - session: server::Session, - ) -> Result<(Self, bool, Session), Self::Error> { - if let Some(a) = self.channel.take() { - println!("channel open session {:?}", a); - a.send(channel).unwrap(); - } - Ok((self, true, session)) - } - } - - let mut sh = ServerHandle { channel: None }; - let scw = sh.get_channel_waiter(); - - test_session( - Client {}, - sh, - |client| async move { - let ch = client.channel_open_session().await.unwrap(); - let mut stream = ch.into_stream(); - stream.write_all(&b"request"[..]).await.unwrap(); - - let mut buf = Vec::new(); - stream.read_buf(&mut buf).await.unwrap(); - assert_eq!(&buf, &b"response"[..]); - - stream.write_all(&b"reply"[..]).await.unwrap(); - - client - }, - |server| async move { - let channel = scw.await.unwrap(); - let mut stream = channel.into_stream(); - - let mut buf = Vec::new(); - stream.read_buf(&mut buf).await.unwrap(); - assert_eq!(&buf, &b"request"[..]); - - stream.write_all(&b"response"[..]).await.unwrap(); - - buf.clear(); - - stream.read_buf(&mut buf).await.unwrap(); - assert_eq!(&buf, &b"reply"[..]); - - server - }, - ) - .await; - } - - #[tokio::test] - async fn test_channel_objects() { - #[derive(Debug)] - struct Client {} - - #[async_trait] - impl client::Handler for Client { - type Error = crate::Error; - - async fn check_server_key( - self, - _server_public_key: &russh_keys::key::PublicKey, - ) -> Result<(Self, bool), Self::Error> { - Ok((self, true)) - } - } - - struct ServerHandle {} - - impl ServerHandle {} - - #[async_trait] - impl server::Handler for ServerHandle { - type Error = crate::Error; - - async fn auth_publickey( - self, - _: &str, - _: &russh_keys::key::PublicKey, - ) -> Result<(Self, server::Auth), Self::Error> { - Ok((self, server::Auth::Accept)) - } - - async fn channel_open_session( - self, - mut channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - tokio::spawn(async move { - while let Some(msg) = channel.wait().await { - match msg { - ChannelMsg::Data { data } => { - channel.data(&data[..]).await.unwrap(); - channel.close().await.unwrap(); - break - } - _ => {} - } - } - }); - Ok((self, true, session)) - } - } - - let sh = ServerHandle {}; - test_session( - Client {}, - sh, - |c| async move { - let mut ch = c.channel_open_session().await.unwrap(); - ch.data(&b"hello world!"[..]).await.unwrap(); - - let msg = ch.wait().await.unwrap(); - if let ChannelMsg::Data { data } = msg { - assert_eq!(data.as_ref(), &b"hey there!"[..]); - } else { - panic!("Unexpected message {:?}", msg); - } - - let msg = ch.wait().await.unwrap(); - let ChannelMsg::Close = msg else { - panic!("Unexpected message {:?}", msg); - }; - - ch.close().await.unwrap(); - c - }, - |s| async move { s }, - ) - .await; +/// Returns `f(val)` if `val` it is [Some], or a forever pending [Future] if it is [None]. +pub(crate) fn future_or_pending, T>( + val: Option, + f: impl FnOnce(T) -> F, +) -> EitherFuture, F> { + match val { + None => EitherFuture::Left(core::future::pending()), + Some(x) => EitherFuture::Right(f(x)), } } diff --git a/russh/src/mac/mod.rs b/russh/src/mac/mod.rs index 5eada31b..d24339d0 100644 --- a/russh/src/mac/mod.rs +++ b/russh/src/mac/mod.rs @@ -14,13 +14,16 @@ //! //! This module exports cipher names for use with [Preferred]. use std::collections::HashMap; +use std::convert::TryFrom; use std::marker::PhantomData; +use delegate::delegate; use digest::typenum::{U20, U32, U64}; use hmac::Hmac; use once_cell::sync::Lazy; use sha1::Sha1; use sha2::{Sha256, Sha512}; +use ssh_encoding::Encode; use self::crypto::CryptoMacAlgorithm; use self::crypto_etm::CryptoEtmMacAlgorithm; @@ -52,6 +55,20 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + +impl TryFrom<&str> for Name { + type Error = (); + fn try_from(s: &str) -> Result { + MACS.keys().find(|x| x.0 == s).map(|x| **x).ok_or(()) + } +} + /// `none` pub const NONE: Name = Name("none"); /// `hmac-sha1` @@ -67,20 +84,30 @@ pub const HMAC_SHA256_ETM: Name = Name("hmac-sha2-256-etm@openssh.com"); /// `hmac-sha2-512-etm@openssh.com` pub const HMAC_SHA512_ETM: Name = Name("hmac-sha2-512-etm@openssh.com"); -static _NONE: NoMacAlgorithm = NoMacAlgorithm {}; -static _HMAC_SHA1: CryptoMacAlgorithm, U20> = +pub(crate) static _NONE: NoMacAlgorithm = NoMacAlgorithm {}; +pub(crate) static _HMAC_SHA1: CryptoMacAlgorithm, U20> = CryptoMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA256: CryptoMacAlgorithm, U32> = +pub(crate) static _HMAC_SHA256: CryptoMacAlgorithm, U32> = CryptoMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA512: CryptoMacAlgorithm, U64> = +pub(crate) static _HMAC_SHA512: CryptoMacAlgorithm, U64> = CryptoMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA1_ETM: CryptoEtmMacAlgorithm, U64> = +pub(crate) static _HMAC_SHA1_ETM: CryptoEtmMacAlgorithm, U20> = CryptoEtmMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA256_ETM: CryptoEtmMacAlgorithm, U64> = +pub(crate) static _HMAC_SHA256_ETM: CryptoEtmMacAlgorithm, U32> = CryptoEtmMacAlgorithm(PhantomData, PhantomData); -static _HMAC_SHA512_ETM: CryptoEtmMacAlgorithm, U64> = +pub(crate) static _HMAC_SHA512_ETM: CryptoEtmMacAlgorithm, U64> = CryptoEtmMacAlgorithm(PhantomData, PhantomData); +pub const ALL_MAC_ALGORITHMS: &[&Name] = &[ + &NONE, + &HMAC_SHA1, + &HMAC_SHA256, + &HMAC_SHA512, + &HMAC_SHA1_ETM, + &HMAC_SHA256_ETM, + &HMAC_SHA512_ETM, +]; + pub(crate) static MACS: Lazy> = Lazy::new(|| { let mut h: HashMap<&'static Name, &(dyn MacAlgorithm + Send + Sync)> = HashMap::new(); @@ -91,5 +118,6 @@ pub(crate) static MACS: Lazy Option { + order.get(seqno).map(|expected| expected == &msg_type) +} + +/// Validate a message+seqno against multiple strict kex order patterns +fn validate_msg_strict_kex_alt_order(msg_type: u8, seqno: usize, orders: &[&[u8]]) -> Option { + let mut valid = None; // did not match yet + for order in orders { + let result = validate_msg_strict_kex(msg_type, seqno, order); + valid = match (valid, result) { + // If we matched a valid msg, it's now valid forever + (Some(true), _) | (_, Some(true)) => Some(true), + // If we matched an invalid msg and we didn't find a valid one yet, it's now invalid + (None | Some(false), Some(false)) => Some(false), + // If the message was beyond the current pattern, no change + (x, None) => x, + }; + } + valid +} + +pub(crate) fn validate_client_msg_strict_kex(msg_type: u8, seqno: usize) -> Result<(), Error> { + if Some(false) + == validate_msg_strict_kex_alt_order( + msg_type, + seqno, + &[ + &[KEXINIT, KEX_ECDH_INIT, NEWKEYS], + &[KEXINIT, KEX_DH_GEX_REQUEST, KEX_DH_GEX_INIT, NEWKEYS], + ], + ) + { + return Err(strict_kex_violation(msg_type, seqno)); + } + Ok(()) +} + +pub(crate) fn validate_server_msg_strict_kex(msg_type: u8, seqno: usize) -> Result<(), Error> { + if Some(false) + == validate_msg_strict_kex_alt_order( + msg_type, + seqno, + &[ + &[KEXINIT, KEX_ECDH_REPLY, NEWKEYS], + &[KEXINIT, KEX_DH_GEX_GROUP, KEX_DH_GEX_REPLY, NEWKEYS], + ], + ) + { + return Err(strict_kex_violation(msg_type, seqno)); + } + Ok(()) +} + +const ALL_KEX_MESSAGES: &[u8] = &[ + KEXINIT, + KEX_ECDH_INIT, + KEX_ECDH_REPLY, + KEX_DH_GEX_GROUP, + KEX_DH_GEX_INIT, + KEX_DH_GEX_REPLY, + KEX_DH_GEX_REQUEST, + NEWKEYS, +]; + +pub(crate) fn is_kex_msg(msg: u8) -> bool { + ALL_KEX_MESSAGES.contains(&msg) +} diff --git a/russh/src/negotiation.rs b/russh/src/negotiation.rs index e8a7cf7e..b59c6531 100644 --- a/russh/src/negotiation.rs +++ b/russh/src/negotiation.rs @@ -12,50 +12,95 @@ // See the License for the specific language governing permissions and // limitations under the License. // -use std::str::from_utf8; +use std::borrow::Cow; -use rand::RngCore; use log::debug; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::{Encoding, Reader}; -use russh_keys::key; -use russh_keys::key::{KeyPair, PublicKey}; +use rand::RngCore; +use ssh_encoding::{Decode, Encode}; +use ssh_key::{Algorithm, EcdsaCurve, HashAlg, PrivateKey}; use crate::cipher::CIPHERS; -use crate::compression::*; -use crate::{cipher, kex, mac, msg, Error}; +use crate::helpers::NameList; +use crate::kex::{EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER}; +#[cfg(not(target_arch = "wasm32"))] +use crate::server::Config; +use crate::sshbuffer::PacketWriter; +use crate::{cipher, compression, kex, mac, msg, AlgorithmKind, CryptoVec, Error}; + +#[cfg(target_arch = "wasm32")] +/// WASM-only stub +pub struct Config { + keys: Vec, +} -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Names { pub kex: kex::Name, - pub key: key::Name, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] + pub key: Algorithm, pub cipher: cipher::Name, pub client_mac: mac::Name, pub server_mac: mac::Name, - pub server_compression: Compression, - pub client_compression: Compression, + pub server_compression: compression::Compression, + pub client_compression: compression::Compression, pub ignore_guessed: bool, + pub strict_kex: bool, } /// Lists of preferred algorithms. This is normally hard-coded into implementations. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Preferred { /// Preferred key exchange algorithms. - pub kex: &'static [kex::Name], - /// Preferred public key algorithms. - pub key: &'static [key::Name], + pub kex: Cow<'static, [kex::Name]>, + /// Preferred host & public key algorithms. + pub key: Cow<'static, [Algorithm]>, /// Preferred symmetric ciphers. - pub cipher: &'static [cipher::Name], + pub cipher: Cow<'static, [cipher::Name]>, /// Preferred MAC algorithms. - pub mac: &'static [mac::Name], + pub mac: Cow<'static, [mac::Name]>, /// Preferred compression algorithms. - pub compression: &'static [&'static str], + pub compression: Cow<'static, [compression::Name]>, +} + +pub(crate) fn is_key_compatible_with_algo(key: &PrivateKey, algo: &Algorithm) -> bool { + match algo { + // All RSA keys are compatible with all RSA based algos. + Algorithm::Rsa { .. } => key.algorithm().is_rsa(), + // Other keys have to match exactly + a => key.algorithm() == *a, + } +} + +impl Preferred { + pub(crate) fn possible_host_key_algos_for_keys( + &self, + available_host_keys: &[PrivateKey], + ) -> Vec { + self.key + .iter() + .filter(|n| { + available_host_keys + .iter() + .any(|k| is_key_compatible_with_algo(k, n)) + }) + .cloned() + .collect::>() + } } const SAFE_KEX_ORDER: &[kex::Name] = &[ kex::CURVE25519, kex::CURVE25519_PRE_RFC_8731, + kex::DH_GEX_SHA256, + kex::DH_G18_SHA512, + kex::DH_G17_SHA512, + kex::DH_G16_SHA512, + kex::DH_G15_SHA512, kex::DH_G14_SHA256, + kex::EXTENSION_SUPPORT_AS_CLIENT, + kex::EXTENSION_SUPPORT_AS_SERVER, + kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, ]; const CIPHER_ORDER: &[cipher::Name] = &[ @@ -75,31 +120,47 @@ const HMAC_ORDER: &[mac::Name] = &[ mac::HMAC_SHA1, ]; -impl Preferred { - #[cfg(feature = "openssl")] - pub const DEFAULT: Preferred = Preferred { - kex: SAFE_KEX_ORDER, - key: &[key::ED25519, key::RSA_SHA2_256, key::RSA_SHA2_512], - cipher: CIPHER_ORDER, - mac: HMAC_ORDER, - compression: &["none", "zlib", "zlib@openssh.com"], - }; +const COMPRESSION_ORDER: &[compression::Name] = &[ + compression::NONE, + #[cfg(feature = "flate2")] + compression::ZLIB, + #[cfg(feature = "flate2")] + compression::ZLIB_LEGACY, +]; - #[cfg(not(feature = "openssl"))] +impl Preferred { pub const DEFAULT: Preferred = Preferred { - kex: SAFE_KEX_ORDER, - key: &[key::ED25519], - cipher: CIPHER_ORDER, - mac: HMAC_ORDER, - compression: &["none", "zlib", "zlib@openssh.com"], + kex: Cow::Borrowed(SAFE_KEX_ORDER), + key: Cow::Borrowed(&[ + Algorithm::Ed25519, + Algorithm::Ecdsa { + curve: EcdsaCurve::NistP256, + }, + Algorithm::Ecdsa { + curve: EcdsaCurve::NistP384, + }, + Algorithm::Ecdsa { + curve: EcdsaCurve::NistP521, + }, + Algorithm::Rsa { + hash: Some(HashAlg::Sha512), + }, + Algorithm::Rsa { + hash: Some(HashAlg::Sha256), + }, + Algorithm::Rsa { hash: None }, + ]), + cipher: Cow::Borrowed(CIPHER_ORDER), + mac: Cow::Borrowed(HMAC_ORDER), + compression: Cow::Borrowed(COMPRESSION_ORDER), }; pub const COMPRESSED: Preferred = Preferred { - kex: SAFE_KEX_ORDER, - key: &[key::ED25519, key::RSA_SHA2_256, key::RSA_SHA2_512], - cipher: CIPHER_ORDER, - mac: HMAC_ORDER, - compression: &["zlib", "zlib@openssh.com", "none"], + kex: Cow::Borrowed(SAFE_KEX_ORDER), + key: Preferred::DEFAULT.key, + cipher: Cow::Borrowed(CIPHER_ORDER), + mac: Cow::Borrowed(HMAC_ORDER), + compression: Cow::Borrowed(COMPRESSION_ORDER), }; } @@ -109,142 +170,154 @@ impl Default for Preferred { } } -/// Named algorithms. -pub trait Named { - /// The name of this algorithm. - fn name(&self) -> &'static str; +pub(crate) fn parse_kex_algo_list(list: &str) -> Vec<&str> { + list.split(',').collect() } -impl Named for () { - fn name(&self) -> &'static str { - "" - } -} +pub(crate) trait Select { + fn is_server() -> bool; + + fn select + Clone>( + a: &[S], + b: &[&str], + kind: AlgorithmKind, + ) -> Result<(bool, S), Error>; + + /// `available_host_keys`, if present, is used to limit the host key algorithms to the ones we have keys for. + fn read_kex( + buffer: &[u8], + pref: &Preferred, + available_host_keys: Option<&[PrivateKey]>, + ) -> Result { + let Some(mut r) = &buffer.get(17..) else { + return Err(Error::Inconsistent); + }; -#[cfg(not(feature = "openssl"))] -use russh_keys::key::ED25519; -#[cfg(feature = "openssl")] -use russh_keys::key::{ED25519, SSH_RSA}; - -impl Named for PublicKey { - fn name(&self) -> &'static str { - match self { - PublicKey::Ed25519(_) => ED25519.0, - #[cfg(feature = "openssl")] - PublicKey::RSA { .. } => SSH_RSA.0, - } - } -} + // Key exchange -impl Named for KeyPair { - fn name(&self) -> &'static str { - match self { - KeyPair::Ed25519 { .. } => ED25519.0, - #[cfg(feature = "openssl")] - KeyPair::RSA { ref hash, .. } => hash.name().0, - } - } -} + let kex_string = String::decode(&mut r)?; + let (kex_both_first, kex_algorithm) = Self::select( + &pref.kex, + &parse_kex_algo_list(&kex_string), + AlgorithmKind::Kex, + )?; -pub trait Select { - fn select + Copy>(a: &[S], b: &[u8]) -> Option<(bool, S)>; + // Strict kex detection - fn read_kex(buffer: &[u8], pref: &Preferred) -> Result { - let mut r = buffer.reader(17); - let kex_string = r.read_string()?; - let (kex_both_first, kex_algorithm) = if let Some(x) = Self::select(pref.kex, kex_string) { - x + let strict_kex_requested = pref.kex.contains(if Self::is_server() { + &EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER } else { - debug!( - "Could not find common kex algorithm, other side only supports {:?}, we only support {:?}", - from_utf8(kex_string), - pref.kex - ); - return Err(Error::NoCommonKexAlgo); - }; + &EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT + }); + let strict_kex_provided = Self::select( + &[if Self::is_server() { + EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT + } else { + EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER + }], + &parse_kex_algo_list(&kex_string), + AlgorithmKind::Kex, + ) + .is_ok(); + + if strict_kex_requested && strict_kex_provided { + debug!("strict kex enabled") + } - let key_string = r.read_string()?; - let (key_both_first, key_algorithm) = if let Some(x) = Self::select(pref.key, key_string) { - x - } else { - debug!( - "Could not find common key algorithm, other side only supports {:?}, we only support {:?}", - from_utf8(key_string), - pref.key - ); - return Err(Error::NoCommonKeyAlgo); + // Host key + + let key_string = String::decode(&mut r)?; + let possible_host_key_algos = match available_host_keys { + Some(available_host_keys) => pref.possible_host_key_algos_for_keys(available_host_keys), + None => pref.key.iter().map(ToOwned::to_owned).collect::>(), }; - let cipher_string = r.read_string()?; - let cipher = Self::select(pref.cipher, cipher_string); - if cipher.is_none() { - debug!( - "Could not find common cipher, other side only supports {:?}, we only support {:?}", - from_utf8(cipher_string), - pref.cipher - ); - return Err(Error::NoCommonCipher); - } - r.read_string()?; // cipher server-to-client. - debug!("kex {}", line!()); - - let need_mac = cipher - .and_then(|x| CIPHERS.get(&x.1)) - .map(|x| x.needs_mac()) - .unwrap_or(false); - - let client_mac = if let Some((_, m)) = Self::select(pref.mac, r.read_string()?) { - m - } else if need_mac { - return Err(Error::NoCommonMac); - } else { - mac::NONE + let (key_both_first, key_algorithm) = Self::select( + &possible_host_key_algos[..], + &parse_kex_algo_list(&key_string), + AlgorithmKind::Key, + )?; + + // Cipher + + let cipher_string = String::decode(&mut r)?; + let (_cipher_both_first, cipher) = Self::select( + &pref.cipher, + &parse_kex_algo_list(&cipher_string), + AlgorithmKind::Cipher, + )?; + String::decode(&mut r)?; // cipher server-to-client. + + // MAC + + let need_mac = CIPHERS.get(&cipher).map(|x| x.needs_mac()).unwrap_or(false); + + let client_mac = match Self::select( + &pref.mac, + &parse_kex_algo_list(&String::decode(&mut r)?), + AlgorithmKind::Mac, + ) { + Ok((_, m)) => m, + Err(e) => { + if need_mac { + return Err(e); + } else { + mac::NONE + } + } }; - let server_mac = if let Some((_, m)) = Self::select(pref.mac, r.read_string()?) { - m - } else if need_mac { - return Err(Error::NoCommonMac); - } else { - mac::NONE + let server_mac = match Self::select( + &pref.mac, + &parse_kex_algo_list(&String::decode(&mut r)?), + AlgorithmKind::Mac, + ) { + Ok((_, m)) => m, + Err(e) => { + if need_mac { + return Err(e); + } else { + mac::NONE + } + } }; - debug!("kex {}", line!()); + // Compression + // client-to-server compression. - let client_compression = - if let Some((_, c)) = Self::select(pref.compression, r.read_string()?) { - Compression::from_string(c) - } else { - return Err(Error::NoCommonCompression); - }; - debug!("kex {}", line!()); + let client_compression = compression::Compression::new( + &Self::select( + &pref.compression, + &parse_kex_algo_list(&String::decode(&mut r)?), + AlgorithmKind::Compression, + )? + .1, + ); + // server-to-client compression. - let server_compression = - if let Some((_, c)) = Self::select(pref.compression, r.read_string()?) { - Compression::from_string(c) - } else { - return Err(Error::NoCommonCompression); - }; - debug!("client_compression = {:?}", client_compression); - r.read_string()?; // languages client-to-server - r.read_string()?; // languages server-to-client - - let follows = r.read_byte()? != 0; - match (cipher, follows) { - (Some((_, cipher)), fol) => { - Ok(Names { - kex: kex_algorithm, - key: key_algorithm, - cipher, - client_mac, - server_mac, - client_compression, - server_compression, - // Ignore the next packet if (1) it follows and (2) it's not the correct guess. - ignore_guessed: fol && !(kex_both_first && key_both_first), - }) - } - _ => Err(Error::KexInit), - } + let server_compression = compression::Compression::new( + &Self::select( + &pref.compression, + &parse_kex_algo_list(&String::decode(&mut r)?), + AlgorithmKind::Compression, + )? + .1, + ); + String::decode(&mut r)?; // languages client-to-server + String::decode(&mut r)?; // languages server-to-client + + let follows = u8::decode(&mut r)? != 0; + Ok(Names { + kex: kex_algorithm, + key: key_algorithm, + cipher, + client_mac, + server_mac, + client_compression, + server_compression, + // Ignore the next packet if (1) it follows and (2) it's not the correct guess. + ignore_guessed: follows && !(kex_both_first && key_both_first), + strict_kex: strict_kex_requested && strict_kex_provided, + }) } } @@ -252,65 +325,168 @@ pub struct Server; pub struct Client; impl Select for Server { - fn select + Copy>(server_list: &[S], client_list: &[u8]) -> Option<(bool, S)> { + fn is_server() -> bool { + true + } + + fn select + Clone>( + server_list: &[S], + client_list: &[&str], + kind: AlgorithmKind, + ) -> Result<(bool, S), Error> { let mut both_first_choice = true; - for c in client_list.split(|&x| x == b',') { - for &s in server_list { - if c == s.as_ref().as_bytes() { - return Some((both_first_choice, s)); + for c in client_list { + for s in server_list { + if c == &s.as_ref() { + return Ok((both_first_choice, s.clone())); } both_first_choice = false } } - None + Err(Error::NoCommonAlgo { + kind, + ours: server_list.iter().map(|x| x.as_ref().to_owned()).collect(), + theirs: client_list.iter().map(|x| (*x).to_owned()).collect(), + }) } } impl Select for Client { - fn select + Copy>(client_list: &[S], server_list: &[u8]) -> Option<(bool, S)> { + fn is_server() -> bool { + false + } + + fn select + Clone>( + client_list: &[S], + server_list: &[&str], + kind: AlgorithmKind, + ) -> Result<(bool, S), Error> { let mut both_first_choice = true; - for &c in client_list { - for s in server_list.split(|&x| x == b',') { - if s == c.as_ref().as_bytes() { - return Some((both_first_choice, c)); + for c in client_list { + for s in server_list { + if s == &c.as_ref() { + return Ok((both_first_choice, c.clone())); } both_first_choice = false } } - None + Err(Error::NoCommonAlgo { + kind, + ours: client_list.iter().map(|x| x.as_ref().to_owned()).collect(), + theirs: server_list.iter().map(|x| (*x).to_owned()).collect(), + }) } } -pub fn write_kex(prefs: &Preferred, buf: &mut CryptoVec, as_server: bool) -> Result<(), Error> { - // buf.clear(); - buf.push(msg::KEXINIT); - - let mut cookie = [0; 16]; - rand::thread_rng().fill_bytes(&mut cookie); +pub(crate) fn write_kex( + prefs: &Preferred, + writer: &mut PacketWriter, + server_config: Option<&Config>, +) -> Result { + writer.packet(|w| { + // buf.clear(); + msg::KEXINIT.encode(w)?; + + let mut cookie = [0; 16]; + rand::thread_rng().fill_bytes(&mut cookie); + for b in cookie { + b.encode(w)?; + } - buf.extend(&cookie); // cookie - buf.extend_list(prefs.kex.iter().filter(|k| { - **k != if as_server { - crate::kex::EXTENSION_SUPPORT_AS_CLIENT + NameList( + prefs + .kex + .iter() + .filter(|k| { + !(if server_config.is_some() { + [ + crate::kex::EXTENSION_SUPPORT_AS_CLIENT, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + ] + } else { + [ + crate::kex::EXTENSION_SUPPORT_AS_SERVER, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, + ] + }) + .contains(*k) + }) + .map(|x| x.as_ref().to_owned()) + .collect(), + ) + .encode(w)?; // kex algo + + if let Some(server_config) = server_config { + // Only advertise host key algorithms that we have keys for. + NameList( + prefs + .key + .iter() + .filter(|algo| { + server_config + .keys + .iter() + .any(|k| is_key_compatible_with_algo(k, algo)) + }) + .map(|x| x.to_string()) + .collect(), + ) + .encode(w)?; } else { - crate::kex::EXTENSION_SUPPORT_AS_SERVER + NameList(prefs.key.iter().map(ToString::to_string).collect()).encode(w)?; } - })); // kex algo - - buf.extend_list(prefs.key.iter()); - - buf.extend_list(prefs.cipher.iter()); // cipher client to server - buf.extend_list(prefs.cipher.iter()); // cipher server to client - - buf.extend_list(prefs.mac.iter()); // mac client to server - buf.extend_list(prefs.mac.iter()); // mac server to client - buf.extend_list(prefs.compression.iter()); // compress client to server - buf.extend_list(prefs.compression.iter()); // compress server to client - - buf.write_empty_list(); // languages client to server - buf.write_empty_list(); // languagesserver to client - buf.push(0); // doesn't follow - buf.extend(&[0, 0, 0, 0]); // reserved - Ok(()) + // cipher client to server + NameList( + prefs + .cipher + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(w)?; + + // cipher server to client + NameList( + prefs + .cipher + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(w)?; + + // mac client to server + NameList(prefs.mac.iter().map(|x| x.as_ref().to_string()).collect()).encode(w)?; + + // mac server to client + NameList(prefs.mac.iter().map(|x| x.as_ref().to_string()).collect()).encode(w)?; + + // compress client to server + NameList( + prefs + .compression + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(w)?; + + // compress server to client + NameList( + prefs + .compression + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(w)?; + + Vec::::new().encode(w)?; // languages client to server + Vec::::new().encode(w)?; // languages server to client + + 0u8.encode(w)?; // doesn't follow + 0u32.encode(w)?; // reserved + Ok(()) + }) } diff --git a/russh/src/parsing.rs b/russh/src/parsing.rs index 77f84e07..50861cd1 100644 --- a/russh/src/parsing.rs +++ b/russh/src/parsing.rs @@ -1,7 +1,7 @@ -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::{Encoding, Position}; +use ssh_encoding::{Decode, Encode, Reader}; -use crate::msg; +use crate::helpers::map_err; +use crate::{msg, CryptoVec}; #[derive(Debug)] pub struct OpenChannelMessage { @@ -12,30 +12,33 @@ pub struct OpenChannelMessage { } impl OpenChannelMessage { - pub fn parse(r: &mut Position) -> Result { + pub fn parse(r: &mut R) -> Result { // https://tools.ietf.org/html/rfc4254#section-5.1 - let typ = r.read_string().map_err(crate::Error::from)?; - let sender = r.read_u32().map_err(crate::Error::from)?; - let window = r.read_u32().map_err(crate::Error::from)?; - let maxpacket = r.read_u32().map_err(crate::Error::from)?; - - let typ = match typ { - b"session" => ChannelType::Session, - b"x11" => { - let originator_address = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let originator_port = r.read_u32().map_err(crate::Error::from)?; + let typ = map_err!(String::decode(r))?; + let sender = map_err!(u32::decode(r))?; + let window = map_err!(u32::decode(r))?; + let maxpacket = map_err!(u32::decode(r))?; + + let typ = match typ.as_str() { + "session" => ChannelType::Session, + "x11" => { + let originator_address = map_err!(String::decode(r))?; + let originator_port = map_err!(u32::decode(r))?; ChannelType::X11 { originator_address, originator_port, } } - b"direct-tcpip" => ChannelType::DirectTcpip(TcpChannelInfo::new(r)?), - b"forwarded-tcpip" => ChannelType::ForwardedTcpIp(TcpChannelInfo::new(r)?), - b"auth-agent@openssh.com" => ChannelType::AgentForward, - t => ChannelType::Unknown { typ: t.to_vec() }, + "direct-tcpip" => ChannelType::DirectTcpip(TcpChannelInfo::decode(r)?), + "direct-streamlocal@openssh.com" => { + ChannelType::DirectStreamLocal(StreamLocalChannelInfo::decode(r)?) + } + "forwarded-tcpip" => ChannelType::ForwardedTcpIp(TcpChannelInfo::decode(r)?), + "forwarded-streamlocal@openssh.com" => { + ChannelType::ForwardedStreamLocal(StreamLocalChannelInfo::decode(r)?) + } + "auth-agent@openssh.com" => ChannelType::AgentForward, + _ => ChannelType::Unknown { typ }, }; Ok(Self { @@ -53,34 +56,41 @@ impl OpenChannelMessage { sender_channel: u32, window_size: u32, packet_size: u32, - ) { + ) -> Result<(), crate::Error> { push_packet!(buffer, { - buffer.push(msg::CHANNEL_OPEN_CONFIRMATION); - buffer.push_u32_be(self.recipient_channel); // remote channel number. - buffer.push_u32_be(sender_channel); // our channel number. - buffer.push_u32_be(window_size); - buffer.push_u32_be(packet_size); + msg::CHANNEL_OPEN_CONFIRMATION.encode(buffer)?; + self.recipient_channel.encode(buffer)?; // remote channel number. + sender_channel.encode(buffer)?; // our channel number. + window_size.encode(buffer)?; + packet_size.encode(buffer)?; }); + Ok(()) } /// Pushes a failure message to the vec. - pub fn fail(&self, buffer: &mut CryptoVec, reason: u8, message: &[u8]) { + pub fn fail( + &self, + buffer: &mut CryptoVec, + reason: u8, + message: &[u8], + ) -> Result<(), crate::Error> { push_packet!(buffer, { - buffer.push(msg::CHANNEL_OPEN_FAILURE); - buffer.push_u32_be(self.recipient_channel); - buffer.push_u32_be(reason as u32); - buffer.extend_ssh_string(message); - buffer.extend_ssh_string(b"en"); + msg::CHANNEL_OPEN_FAILURE.encode(buffer)?; + self.recipient_channel.encode(buffer)?; + (reason as u32).encode(buffer)?; + message.encode(buffer)?; + "en".encode(buffer)?; }); + Ok(()) } /// Pushes an unknown type error to the vec. - pub fn unknown_type(&self, buffer: &mut CryptoVec) { + pub fn unknown_type(&self, buffer: &mut CryptoVec) -> Result<(), crate::Error> { self.fail( buffer, msg::SSH_OPEN_UNKNOWN_CHANNEL_TYPE, b"Unknown channel type", - ); + ) } } @@ -92,10 +102,12 @@ pub enum ChannelType { originator_port: u32, }, DirectTcpip(TcpChannelInfo), + DirectStreamLocal(StreamLocalChannelInfo), ForwardedTcpIp(TcpChannelInfo), + ForwardedStreamLocal(StreamLocalChannelInfo), AgentForward, Unknown { - typ: Vec, + typ: String, }, } @@ -107,16 +119,28 @@ pub struct TcpChannelInfo { pub originator_port: u32, } -impl TcpChannelInfo { - fn new(r: &mut Position) -> Result { - let host_to_connect = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let port_to_connect = r.read_u32().map_err(crate::Error::from)?; - let originator_address = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let originator_port = r.read_u32().map_err(crate::Error::from)?; +#[derive(Debug)] +pub struct StreamLocalChannelInfo { + pub socket_path: String, +} + +impl Decode for StreamLocalChannelInfo { + type Error = ssh_encoding::Error; + + fn decode(r: &mut impl Reader) -> Result { + let socket_path = String::decode(r)?.to_owned(); + Ok(Self { socket_path }) + } +} + +impl Decode for TcpChannelInfo { + type Error = ssh_encoding::Error; + + fn decode(r: &mut impl Reader) -> Result { + let host_to_connect = String::decode(r)?; + let port_to_connect = u32::decode(r)?; + let originator_address = String::decode(r)?; + let originator_port = u32::decode(r)?; Ok(Self { host_to_connect, @@ -135,12 +159,14 @@ pub(crate) struct ChannelOpenConfirmation { pub maximum_packet_size: u32, } -impl ChannelOpenConfirmation { - pub fn parse(r: &mut Position) -> Result { - let recipient_channel = r.read_u32().map_err(crate::Error::from)?; - let sender_channel = r.read_u32().map_err(crate::Error::from)?; - let initial_window_size = r.read_u32().map_err(crate::Error::from)?; - let maximum_packet_size = r.read_u32().map_err(crate::Error::from)?; +impl Decode for ChannelOpenConfirmation { + type Error = ssh_encoding::Error; + + fn decode(r: &mut impl Reader) -> Result { + let recipient_channel = u32::decode(r)?; + let sender_channel = u32::decode(r)?; + let initial_window_size = u32::decode(r)?; + let maximum_packet_size = u32::decode(r)?; Ok(Self { recipient_channel, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index abfa2555..262601ab 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -12,123 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. // +use core::str; use std::cell::RefCell; +use std::time::SystemTime; use auth::*; use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; +use cert::PublicKeyOrCertificate; use log::{debug, error, info, trace, warn}; -use negotiation::Select; -use russh_keys::encoding::{Encoding, Position, Reader}; -use russh_keys::key; -use russh_keys::key::Verify; -use tokio::sync::mpsc::unbounded_channel; +use msg; +use signature::Verifier; +use ssh_encoding::{Decode, Encode, Reader}; +use ssh_key::{PublicKey, Signature}; use tokio::time::Instant; -use {msg, negotiation}; use super::super::*; use super::*; +use crate::helpers::NameList; +use crate::map_err; use crate::msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; impl Session { /// Returns false iff a request was rejected. pub(crate) async fn server_read_encrypted( - mut self, - mut handler: H, - buf: &[u8], - ) -> Result<(H, Self), H::Error> { - #[allow(clippy::indexing_slicing)] // length checked - { - trace!( - "server_read_encrypted, buf = {:?}", - &buf[..buf.len().min(20)] - ); - } - // Either this packet is a KEXINIT, in which case we start a key re-exchange. - - #[allow(clippy::unwrap_used)] - let mut enc = self.common.encrypted.as_mut().unwrap(); - if buf.first() == Some(&msg::KEXINIT) { - debug!("Received rekeying request"); - // If we're not currently rekeying, but `buf` is a rekey request - if let Some(Kex::Init(kexinit)) = enc.rekey.take() { - enc.rekey = Some(kexinit.server_parse( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - buf, - &mut self.common.write_buffer, - )?); - } else if let Some(exchange) = enc.exchange.take() { - let kexinit = KexInit::received_rekey( - exchange, - negotiation::Server::read_kex(buf, &self.common.config.as_ref().preferred)?, - &enc.session_id, - ); - enc.rekey = Some(kexinit.server_parse( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - buf, - &mut self.common.write_buffer, - )?); - } - self.flush()?; - return Ok((handler, self)); - } - - match enc.rekey.take() { - Some(Kex::Dh(kexdh)) => { - enc.rekey = Some(kexdh.parse( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - buf, - &mut self.common.write_buffer, - )?); - self.flush()?; - return Ok((handler, self)); - } - Some(Kex::Keys(newkeys)) => { - if buf.first() != Some(&msg::NEWKEYS) { - return Err(Error::Kex.into()); - } - self.common.write_buffer.bytes = 0; - enc.last_rekey = std::time::Instant::now(); - - // Ok, NEWKEYS received, now encrypted. - enc.flush_all_pending(); - let mut pending = std::mem::take(&mut self.pending_reads); - for p in pending.drain(..) { - let (h, s) = self.process_packet(handler, &p).await?; - handler = h; - self = s; - } - self.pending_reads = pending; - self.pending_len = 0; - self.common.newkeys(newkeys); - self.flush()?; - return Ok((handler, self)); - } - Some(Kex::Init(k)) => { - enc.rekey = Some(Kex::Init(k)); - self.pending_len += buf.len() as u32; - if self.pending_len > 2 * self.target_window_size { - return Err(Error::Pending.into()); - } - self.pending_reads.push(CryptoVec::from_slice(buf)); - return Ok((handler, self)); - } - rek => { - trace!("rek = {:?}", rek); - enc.rekey = rek - } - } - self.process_packet(handler, buf).await + &mut self, + handler: &mut H, + pkt: &mut IncomingSshPacket, + ) -> Result<(), H::Error> { + self.process_packet(handler, &pkt.buffer).await } - async fn process_packet( - mut self, - mut handler: H, + pub(crate) async fn process_packet( + &mut self, + handler: &mut H, buf: &[u8], - ) -> Result<(H, Self), H::Error> { + ) -> Result<(), H::Error> { let rejection_wait_until = tokio::time::Instant::now() + self.common.config.auth_rejection_time; let initial_none_rejection_wait_until = if self.common.auth_attempts == 0 { @@ -142,103 +62,107 @@ impl Session { rejection_wait_until }; - #[allow(clippy::unwrap_used)] - let mut enc = self.common.encrypted.as_mut().unwrap(); + let Some(enc) = self.common.encrypted.as_mut() else { + return Err(Error::Inconsistent.into()); + }; + // If we've successfully read a packet. - match enc.state { - EncryptedState::WaitingAuthServiceRequest { - ref mut accepted, .. - } if buf.first() == Some(&msg::SERVICE_REQUEST) => { - let mut r = buf.reader(1); - let request = r.read_string().map_err(crate::Error::from)?; - debug!("request: {:?}", std::str::from_utf8(request)); - if request == b"ssh-userauth" { + match (&mut enc.state, buf.split_first()) { + ( + EncryptedState::WaitingAuthServiceRequest { + ref mut accepted, .. + }, + Some((&msg::SERVICE_REQUEST, mut r)), + ) => { + let request = map_err!(String::decode(&mut r))?; + debug!("request: {:?}", request); + if request == "ssh-userauth" { let auth_request = server_accept_service( - self.common.config.as_ref().auth_banner, - self.common.config.as_ref().methods, + handler.authentication_banner().await?, + self.common.config.as_ref().methods.clone(), &mut enc.write, - ); + )?; *accepted = true; enc.state = EncryptedState::WaitingAuthRequest(auth_request); } - Ok((handler, self)) + Ok(()) } - EncryptedState::WaitingAuthRequest(_) - if buf.first() == Some(&msg::USERAUTH_REQUEST) => - { - handler = enc - .server_read_auth_request( - rejection_wait_until, - initial_none_rejection_wait_until, - handler, - buf, - &mut self.common.auth_user, - ) - .await?; + (EncryptedState::WaitingAuthRequest(_), Some((&msg::USERAUTH_REQUEST, mut r))) => { + enc.server_read_auth_request( + rejection_wait_until, + initial_none_rejection_wait_until, + handler, + buf, + &mut r, + &mut self.common.auth_user, + ) + .await?; self.common.auth_attempts += 1; if let EncryptedState::InitCompression = enc.state { enc.client_compression.init_decompress(&mut enc.decompress); - handler.auth_succeeded(self).await - } else { - Ok((handler, self)) + handler.auth_succeeded(self).await?; } + Ok(()) } - EncryptedState::WaitingAuthRequest(ref mut auth) - if buf.first() == Some(&msg::USERAUTH_INFO_RESPONSE) => - { - let (h, resp) = read_userauth_info_response( + ( + EncryptedState::WaitingAuthRequest(ref mut auth), + Some((&msg::USERAUTH_INFO_RESPONSE, mut r)), + ) => { + let resp = read_userauth_info_response( rejection_wait_until, handler, &mut enc.write, auth, - &mut self.common.auth_user, - buf, + &self.common.auth_user, + &mut r, ) .await?; - handler = h; if resp { enc.state = EncryptedState::InitCompression; enc.client_compression.init_decompress(&mut enc.decompress); handler.auth_succeeded(self).await } else { - Ok((handler, self)) + Ok(()) } } - EncryptedState::InitCompression => { - enc.server_compression.init_compress(&mut enc.compress); + (EncryptedState::InitCompression, Some((msg, mut r))) => { + enc.server_compression + .init_compress(self.common.packet_writer.compress()); enc.state = EncryptedState::Authenticated; - self.server_read_authenticated(handler, buf).await + self.server_read_authenticated(handler, *msg, &mut r).await + } + (EncryptedState::Authenticated, Some((msg, mut r))) => { + self.server_read_authenticated(handler, *msg, &mut r).await } - EncryptedState::Authenticated => self.server_read_authenticated(handler, buf).await, - _ => Ok((handler, self)), + _ => Ok(()), } } } fn server_accept_service( - banner: Option<&str>, + banner: Option, methods: MethodSet, buffer: &mut CryptoVec, -) -> AuthRequest { +) -> Result { push_packet!(buffer, { buffer.push(msg::SERVICE_ACCEPT); - buffer.extend_ssh_string(b"ssh-userauth"); + "ssh-userauth".encode(buffer)?; }); if let Some(banner) = banner { push_packet!(buffer, { buffer.push(msg::USERAUTH_BANNER); - buffer.extend_ssh_string(banner.as_bytes()); - buffer.extend_ssh_string(b""); + banner.encode(buffer)?; + "".encode(buffer)?; }) } - AuthRequest { + Ok(AuthRequest { methods, partial_success: false, // not used immediately anway. current: None, rejection_count: 0, - } + }) } impl Encrypted { @@ -247,25 +171,19 @@ impl Encrypted { &mut self, mut until: Instant, initial_auth_until: Instant, - mut handler: H, - buf: &[u8], + handler: &mut H, + original_packet: &[u8], + r: &mut &[u8], auth_user: &mut String, - ) -> Result { + ) -> Result<(), H::Error> { // https://tools.ietf.org/html/rfc4252#section-5 - let mut r = buf.reader(1); - let user = r.read_string().map_err(crate::Error::from)?; - let user = std::str::from_utf8(user).map_err(crate::Error::from)?; - let service_name = r.read_string().map_err(crate::Error::from)?; - let method = r.read_string().map_err(crate::Error::from)?; - debug!( - "name: {:?} {:?} {:?}", - user, - std::str::from_utf8(service_name), - std::str::from_utf8(method) - ); + let user = map_err!(String::decode(r))?; + let service_name = map_err!(String::decode(r))?; + let method = map_err!(String::decode(r))?; + debug!("name: {user:?} {service_name:?} {method:?}",); - if service_name == b"ssh-connection" { - if method == b"password" { + if service_name == "ssh-connection" { + if method == "password" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -273,25 +191,40 @@ impl Encrypted { unreachable!() }; auth_user.clear(); - auth_user.push_str(user); - r.read_byte().map_err(crate::Error::from)?; - let password = r.read_string().map_err(crate::Error::from)?; - let password = std::str::from_utf8(password).map_err(crate::Error::from)?; - let (handler, auth) = handler.auth_password(user, password).await?; + auth_user.push_str(&user); + map_err!(u8::decode(r))?; + let password = map_err!(String::decode(r))?; + let auth = handler.auth_password(&user, &password).await?; if let Auth::Accept = auth { server_auth_request_success(&mut self.write); self.state = EncryptedState::InitCompression; } else { auth_user.clear(); - auth_request.methods -= MethodSet::PASSWORD; + if let Auth::Reject { + proceed_with_methods: Some(proceed_with_methods), + partial_success, + } = auth + { + auth_request.methods = proceed_with_methods; + auth_request.partial_success = partial_success; + } else { + auth_request.methods.remove(MethodKind::Password); + } auth_request.partial_success = false; - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } - Ok(handler) - } else if method == b"publickey" { - self.server_read_auth_request_pk(until, handler, buf, auth_user, user, r) - .await - } else if method == b"none" { + Ok(()) + } else if method == "publickey" { + self.server_read_auth_request_pk( + until, + handler, + original_packet, + auth_user, + &user, + r, + ) + .await + } else if method == "none" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -299,22 +232,29 @@ impl Encrypted { unreachable!() }; - if method == b"none" { - until = initial_auth_until - } + until = initial_auth_until; - let (handler, auth) = handler.auth_none(user).await?; + let auth = handler.auth_none(&user).await?; if let Auth::Accept = auth { server_auth_request_success(&mut self.write); self.state = EncryptedState::InitCompression; } else { auth_user.clear(); - auth_request.methods -= MethodSet::NONE; + if let Auth::Reject { + proceed_with_methods: Some(proceed_with_methods), + partial_success, + } = auth + { + auth_request.methods = proceed_with_methods; + auth_request.partial_success = partial_success; + } else { + auth_request.methods.remove(MethodKind::None); + } auth_request.partial_success = false; - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } - Ok(handler) - } else if method == b"keyboard-interactive" { + Ok(()) + } else if method == "keyboard-interactive" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -322,22 +262,20 @@ impl Encrypted { unreachable!() }; auth_user.clear(); - auth_user.push_str(user); - let _ = r.read_string().map_err(crate::Error::from)?; // language_tag, deprecated. - let submethods = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + auth_user.push_str(&user); + let _ = map_err!(String::decode(r))?; // language_tag, deprecated. + let submethods = map_err!(String::decode(r))?; debug!("{:?}", submethods); auth_request.current = Some(CurrentRequest::KeyboardInteractive { submethods: submethods.to_string(), }); - let (h, auth) = handler - .auth_keyboard_interactive(user, submethods, None) + let auth = handler + .auth_keyboard_interactive(&user, &submethods, None) .await?; - handler = h; if reply_userauth_info_response(until, auth_request, &mut self.write, auth).await? { self.state = EncryptedState::InitCompression } - Ok(handler) + Ok(()) } else { // Other methods of the base specification are insecure or optional. let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state @@ -346,8 +284,8 @@ impl Encrypted { } else { unreachable!() }; - reject_auth_request(until, &mut self.write, auth_request).await; - Ok(handler) + reject_auth_request(until, &mut self.write, auth_request).await?; + Ok(()) } } else { // Unknown service @@ -364,27 +302,64 @@ impl Encrypted { async fn server_read_auth_request_pk( &mut self, until: Instant, - mut handler: H, - buf: &[u8], + handler: &mut H, + original_packet: &[u8], auth_user: &mut String, user: &str, - mut r: Position<'_>, - ) -> Result { + r: &mut &[u8], + ) -> Result<(), H::Error> { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a } else { unreachable!() }; - let is_real = r.read_byte().map_err(crate::Error::from)?; - let pubkey_algo = r.read_string().map_err(crate::Error::from)?; - let pubkey_key = r.read_string().map_err(crate::Error::from)?; - debug!("algo: {:?}, key: {:?}", pubkey_algo, pubkey_key); - match key::PublicKey::parse(pubkey_algo, pubkey_key) { - Ok(mut pubkey) => { + + let is_real = map_err!(u8::decode(r))?; + + let pubkey_algo = map_err!(String::decode(r))?; + let pubkey_key = map_err!(Bytes::decode(r))?; + let key_or_cert = PublicKeyOrCertificate::decode(&pubkey_algo, &pubkey_key); + + // Parse the public key or certificate + match key_or_cert { + Ok(pk_or_cert) => { debug!("is_real = {:?}", is_real); + // Handle certificates specifically + let pubkey = match pk_or_cert { + PublicKeyOrCertificate::PublicKey { ref key, .. } => key.clone(), + PublicKeyOrCertificate::Certificate(ref cert) => { + // Validate certificate expiration + let now = SystemTime::now(); + if now < cert.valid_after_time() || now > cert.valid_before_time() { + warn!("Certificate is expired or not yet valid"); + reject_auth_request(until, &mut self.write, auth_request).await?; + return Ok(()); + } + + // Verify the certificate’s signature + if cert.verify_signature().is_err() { + warn!("Certificate signature is invalid"); + reject_auth_request(until, &mut self.write, auth_request).await?; + return Ok(()); + } + + // Use certificate's public key for authentication + PublicKey::new(cert.public_key().clone(), "") + } + }; + if is_real != 0 { - let pos0 = r.position; + // SAFETY: both original_packet and pos0 are coming + // from the same allocation (pos0 is derived from + // a slice of the original_packet) + let sig_init_buffer = { + let pos0 = r.as_ptr(); + let init_len = unsafe { pos0.offset_from(original_packet.as_ptr()) }; + #[allow(clippy::indexing_slicing)] // length checked + &original_packet[0..init_len as usize] + }; + let sent_pk_ok = if let Some(CurrentRequest::PublicKey { sent_pk_ok, .. }) = auth_request.current { @@ -393,66 +368,82 @@ impl Encrypted { false }; - let signature = r.read_string().map_err(crate::Error::from)?; - debug!("signature = {:?}", signature); - let mut s = signature.reader(0); - let algo_ = s.read_string().map_err(crate::Error::from)?; - pubkey.set_algorithm(algo_); - debug!("algo_: {:?}", algo_); - let sig = s.read_string().map_err(crate::Error::from)?; - #[allow(clippy::indexing_slicing)] // length checked - let init = &buf[0..pos0]; + let encoded_signature = map_err!(Vec::::decode(r))?; + + let sig = map_err!(Signature::decode(&mut encoded_signature.as_slice()))?; let is_valid = if sent_pk_ok && user == auth_user { true } else if auth_user.is_empty() { auth_user.clear(); auth_user.push_str(user); - let (h, auth) = handler.auth_publickey(user, &pubkey).await?; - handler = h; + let auth = handler.auth_publickey_offered(user, &pubkey).await?; auth == Auth::Accept } else { false }; + if is_valid { let session_id = self.session_id.as_ref(); - #[allow(clippy::blocks_in_if_conditions)] // length checked + #[allow(clippy::blocks_in_conditions)] if SIGNATURE_BUFFER.with(|buf| { let mut buf = buf.borrow_mut(); buf.clear(); - buf.extend_ssh_string(session_id); - buf.extend(init); - // Verify signature. - pubkey.verify_client_auth(&buf, sig) - }) { + map_err!(session_id.encode(&mut *buf))?; + buf.extend(sig_init_buffer); + + Ok(Verifier::verify(&pubkey, &buf, &sig).is_ok()) + })? { debug!("signature verified"); - server_auth_request_success(&mut self.write); - self.state = EncryptedState::InitCompression; + let auth = match pk_or_cert { + PublicKeyOrCertificate::PublicKey { ref key, .. } => { + handler.auth_publickey(user, key).await? + } + PublicKeyOrCertificate::Certificate(ref cert) => { + handler.auth_openssh_certificate(user, cert).await? + } + }; + + if auth == Auth::Accept { + server_auth_request_success(&mut self.write); + self.state = EncryptedState::InitCompression; + } else { + if let Auth::Reject { + proceed_with_methods: Some(proceed_with_methods), + partial_success, + } = auth + { + auth_request.methods = proceed_with_methods; + auth_request.partial_success = partial_success; + } + auth_request.partial_success = false; + auth_user.clear(); + reject_auth_request(until, &mut self.write, auth_request).await?; + } } else { debug!("signature wrong"); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } } else { - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } - Ok(handler) + Ok(()) } else { auth_user.clear(); auth_user.push_str(user); - let (h, auth) = handler.auth_publickey(user, &pubkey).await?; - handler = h; + let auth = handler.auth_publickey_offered(user, &pubkey).await?; match auth { Auth::Accept => { let mut public_key = CryptoVec::new(); - public_key.extend(pubkey_key); + public_key.extend(&pubkey_key); let mut algo = CryptoVec::new(); - algo.extend(pubkey_algo); + algo.extend(pubkey_algo.as_bytes()); debug!("pubkey_key: {:?}", pubkey_key); push_packet!(self.write, { self.write.push(msg::USERAUTH_PK_OK); - self.write.extend_ssh_string(pubkey_algo); - self.write.extend_ssh_string(pubkey_key); + map_err!(pubkey_algo.encode(&mut self.write))?; + map_err!(pubkey_key.encode(&mut self.write))?; }); auth_request.current = Some(CurrentRequest::PublicKey { @@ -464,26 +455,30 @@ impl Encrypted { auth => { if let Auth::Reject { proceed_with_methods: Some(proceed_with_methods), + partial_success, } = auth { auth_request.methods = proceed_with_methods; + auth_request.partial_success = partial_success; } auth_request.partial_success = false; auth_user.clear(); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } } - Ok(handler) + Ok(()) } } - Err(e) => { - if let russh_keys::Error::CouldNotReadKey = e { - reject_auth_request(until, &mut self.write, auth_request).await; - Ok(handler) - } else { - Err(crate::Error::from(e).into()) + Err(e) => match e { + ssh_key::Error::AlgorithmUnknown + | ssh_key::Error::AlgorithmUnsupported { .. } + | ssh_key::Error::CertificateValidation { .. } => { + debug!("public key error: {e}"); + reject_auth_request(until, &mut self.write, auth_request).await?; + Ok(()) } - } + e => Err(crate::Error::from(e).into()), + }, } } } @@ -492,17 +487,18 @@ async fn reject_auth_request( until: Instant, write: &mut CryptoVec, auth_request: &mut AuthRequest, -) { +) -> Result<(), Error> { debug!("rejecting {:?}", auth_request); push_packet!(write, { write.push(msg::USERAUTH_FAILURE); - write.extend_list(auth_request.methods.into_iter()); + NameList::from(&auth_request.methods).encode(write)?; write.push(auth_request.partial_success as u8); }); auth_request.current = None; auth_request.rejection_count += 1; debug!("packet pushed"); - tokio::time::sleep_until(until).await + tokio::time::sleep_until(until).await; + Ok(()) } fn server_auth_request_success(buffer: &mut CryptoVec) { @@ -511,29 +507,32 @@ fn server_auth_request_success(buffer: &mut CryptoVec) { }) } -async fn read_userauth_info_response( +async fn read_userauth_info_response( until: Instant, - mut handler: H, + handler: &mut H, write: &mut CryptoVec, auth_request: &mut AuthRequest, - user: &mut str, - b: &[u8], -) -> Result<(H, bool), H::Error> { + user: &str, + r: &mut R, +) -> Result { if let Some(CurrentRequest::KeyboardInteractive { ref submethods }) = auth_request.current { - let mut r = b.reader(1); - let n = r.read_u32().map_err(crate::Error::from)?; - let response = Response { pos: r, n }; - let (h, auth) = handler - .auth_keyboard_interactive(user, submethods, Some(response)) + let n = map_err!(u32::decode(r))?; + + let mut responses = Vec::with_capacity(n as usize); + for _ in 0..n { + responses.push(Bytes::decode(r).ok()) + } + + let auth = handler + .auth_keyboard_interactive(user, submethods, Some(Response(&mut responses.into_iter()))) .await?; - handler = h; let resp = reply_userauth_info_response(until, auth_request, write, auth) .await .map_err(H::Error::from)?; - Ok((handler, resp)) + Ok(resp) } else { - reject_auth_request(until, write, auth_request).await; - Ok((handler, false)) + reject_auth_request(until, write, auth_request).await?; + Ok(false) } } @@ -550,12 +549,13 @@ async fn reply_userauth_info_response( } Auth::Reject { proceed_with_methods, + partial_success, } => { if let Some(proceed_with_methods) = proceed_with_methods { auth_request.methods = proceed_with_methods; } - auth_request.partial_success = false; - reject_auth_request(until, write, auth_request).await; + auth_request.partial_success = partial_success; + reject_auth_request(until, write, auth_request).await?; Ok(false) } Auth::Partial { @@ -564,16 +564,17 @@ async fn reply_userauth_info_response( prompts, } => { push_packet!(write, { - write.push(msg::USERAUTH_INFO_REQUEST); - write.extend_ssh_string(name.as_bytes()); - write.extend_ssh_string(instructions.as_bytes()); - write.extend_ssh_string(b""); // lang, should be empty - write.push_u32_be(prompts.len() as u32); + msg::USERAUTH_INFO_REQUEST.encode(write)?; + name.as_ref().encode(write)?; + instructions.as_ref().encode(write)?; + "".encode(write)?; // lang, should be empty + prompts.len().encode(write)?; for &(ref a, b) in prompts.iter() { - write.extend_ssh_string(a.as_bytes()); - write.push(b as u8); + a.as_ref().encode(write)?; + (b as u8).encode(write)?; } - }); + Ok::<(), crate::Error>(()) + })?; Ok(false) } Auth::UnsupportedMethod => unreachable!(), @@ -581,26 +582,19 @@ async fn reply_userauth_info_response( } impl Session { - async fn server_read_authenticated( - mut self, - mut handler: H, - buf: &[u8], - ) -> Result<(H, Self), H::Error> { - #[allow(clippy::indexing_slicing)] // length checked - { - trace!( - "authenticated buf = {:?}", - &buf[..std::cmp::min(buf.len(), 100)] - ); - } - match buf.first() { - Some(&msg::CHANNEL_OPEN) => self - .server_handle_channel_open(handler, buf) + async fn server_read_authenticated( + &mut self, + handler: &mut H, + msg: u8, + r: &mut R, + ) -> Result<(), H::Error> { + match msg { + msg::CHANNEL_OPEN => self + .server_handle_channel_open(handler, r) .await - .map(|(h, _, s)| (h, s)), - Some(&msg::CHANNEL_CLOSE) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + .map(|_| ()), + msg::CHANNEL_CLOSE => { + let channel_num = map_err!(ChannelId::decode(r))?; if let Some(ref mut enc) = self.common.encrypted { enc.channels.remove(&channel_num); } @@ -608,30 +602,28 @@ impl Session { debug!("handler.channel_close {:?}", channel_num); handler.channel_close(channel_num, self).await } - Some(&msg::CHANNEL_EOF) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + msg::CHANNEL_EOF => { + let channel_num = map_err!(ChannelId::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - chan.send(ChannelMsg::Eof).unwrap_or(()) + chan.send(ChannelMsg::Eof).await.unwrap_or(()) } debug!("handler.channel_eof {:?}", channel_num); handler.channel_eof(channel_num, self).await } - Some(&msg::CHANNEL_EXTENDED_DATA) | Some(&msg::CHANNEL_DATA) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + msg::CHANNEL_EXTENDED_DATA | msg::CHANNEL_DATA => { + let channel_num = map_err!(ChannelId::decode(r))?; - let ext = if buf.first() == Some(&msg::CHANNEL_DATA) { + let ext = if msg == msg::CHANNEL_DATA { None } else { - Some(r.read_u32().map_err(crate::Error::from)?) + Some(map_err!(u32::decode(r))?) }; trace!("handler.data {:?} {:?}", ext, channel_num); - let data = r.read_string().map_err(crate::Error::from)?; + let data = map_err!(Bytes::decode(r))?; let target = self.target_window_size; if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, data, target) { + if enc.adjust_window_size(channel_num, &data, target)? { let window = handler.adjust_window(channel_num, self.target_window_size); if window > 0 { self.target_window_size = window @@ -643,26 +635,27 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::ExtendedData { ext, - data: CryptoVec::from_slice(data), + data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } - handler.extended_data(channel_num, ext, data, self).await + handler.extended_data(channel_num, ext, &data, self).await } else { if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(data), + data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } - handler.data(channel_num, data, self).await + handler.data(channel_num, &data, self).await } } - Some(&msg::CHANNEL_WINDOW_ADJUST) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let amount = r.read_u32().map_err(crate::Error::from)?; + msg::CHANNEL_WINDOW_ADJUST => { + let channel_num = map_err!(ChannelId::decode(r))?; + let amount = map_err!(u32::decode(r))?; let mut new_size = 0; if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel_num) { @@ -673,20 +666,22 @@ impl Session { } } if let Some(ref mut enc) = self.common.encrypted { - enc.flush_pending(channel_num); + enc.flush_pending(channel_num)?; } if let Some(chan) = self.channels.get(&channel_num) { + chan.window_size().update(new_size).await; + chan.send(ChannelMsg::WindowAdjusted { new_size }) + .await .unwrap_or(()) } debug!("handler.window_adjusted {:?}", channel_num); handler.window_adjusted(channel_num, new_size, self).await } - Some(&msg::CHANNEL_OPEN_CONFIRMATION) => { + msg::CHANNEL_OPEN_CONFIRMATION => { debug!("channel_open_confirmation"); - let mut reader = buf.reader(1); - let msg = ChannelOpenConfirmation::parse(&mut reader)?; + let msg = map_err!(ChannelOpenConfirmation::decode(r))?; let local_id = ChannelId(msg.recipient_channel); if let Some(ref mut enc) = self.common.encrypted { @@ -707,6 +702,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {:?}", local_id); @@ -721,29 +717,26 @@ impl Session { .await } - Some(&msg::CHANNEL_REQUEST) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let req_type = r.read_string().map_err(crate::Error::from)?; - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + msg::CHANNEL_REQUEST => { + let channel_num = map_err!(ChannelId::decode(r))?; + let req_type = map_err!(String::decode(r))?; + let wants_reply = map_err!(u8::decode(r))?; if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel_num) { channel.wants_reply = wants_reply != 0; } } - match req_type { - b"pty-req" => { - let term = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let col_width = r.read_u32().map_err(crate::Error::from)?; - let row_height = r.read_u32().map_err(crate::Error::from)?; - let pix_width = r.read_u32().map_err(crate::Error::from)?; - let pix_height = r.read_u32().map_err(crate::Error::from)?; + match req_type.as_str() { + "pty-req" => { + let term = map_err!(String::decode(r))?; + let col_width = map_err!(u32::decode(r))?; + let row_height = map_err!(u32::decode(r))?; + let pix_width = map_err!(u32::decode(r))?; + let pix_height = map_err!(u32::decode(r))?; let mut modes = [(Pty::TTY_OP_END, 0); 130]; let mut i = 0; { - let mode_string = r.read_string().map_err(crate::Error::from)?; + let mode_string = map_err!(Bytes::decode(r))?; while 5 * i < mode_string.len() { #[allow(clippy::indexing_slicing)] // length checked let code = mode_string[5 * i]; @@ -768,15 +761,17 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestPty { - want_reply: true, - term: term.into(), - col_width, - row_height, - pix_width, - pix_height, - terminal_modes: modes.into(), - }); + let _ = chan + .send(ChannelMsg::RequestPty { + want_reply: true, + term: term.clone(), + col_width, + row_height, + pix_width, + pix_height, + terminal_modes: modes.into(), + }) + .await; } debug!("handler.pty_request {:?}", channel_num); @@ -784,7 +779,7 @@ impl Session { handler .pty_request( channel_num, - term, + &term, col_width, row_height, pix_width, @@ -794,118 +789,121 @@ impl Session { ) .await } - b"x11-req" => { - let single_connection = r.read_byte().map_err(crate::Error::from)? != 0; - let x11_auth_protocol = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let x11_auth_cookie = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let x11_screen_number = r.read_u32().map_err(crate::Error::from)?; + "x11-req" => { + let single_connection = map_err!(u8::decode(r))? != 0; + let x11_auth_protocol = map_err!(String::decode(r))?; + let x11_auth_cookie = map_err!(String::decode(r))?; + let x11_screen_number = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestX11 { - want_reply: true, - single_connection, - x11_authentication_cookie: x11_auth_cookie.into(), - x11_authentication_protocol: x11_auth_protocol.into(), - x11_screen_number, - }); + let _ = chan + .send(ChannelMsg::RequestX11 { + want_reply: true, + single_connection, + x11_authentication_cookie: x11_auth_cookie.clone(), + x11_authentication_protocol: x11_auth_protocol.clone(), + x11_screen_number, + }) + .await; } debug!("handler.x11_request {:?}", channel_num); handler .x11_request( channel_num, single_connection, - x11_auth_protocol, - x11_auth_cookie, + &x11_auth_protocol, + &x11_auth_cookie, x11_screen_number, self, ) .await } - b"env" => { - let env_variable = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let env_value = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + "env" => { + let env_variable = map_err!(String::decode(r))?; + let env_value = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::SetEnv { - want_reply: true, - variable_name: env_variable.into(), - variable_value: env_value.into(), - }); + let _ = chan + .send(ChannelMsg::SetEnv { + want_reply: true, + variable_name: env_variable.clone(), + variable_value: env_value.clone(), + }) + .await; } debug!("handler.env_request {:?}", channel_num); handler - .env_request(channel_num, env_variable, env_value, self) + .env_request(channel_num, &env_variable, &env_value, self) .await } - b"shell" => { + "shell" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); + let _ = chan + .send(ChannelMsg::RequestShell { want_reply: true }) + .await; } debug!("handler.shell_request {:?}", channel_num); handler.shell_request(channel_num, self).await } - b"auth-agent-req@openssh.com" => { + "auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); + let _ = chan + .send(ChannelMsg::AgentForward { want_reply: true }) + .await; } debug!("handler.agent_request {:?}", channel_num); - let response; - (handler, response, self) = - handler.agent_request(channel_num, self).await?; + + let response = handler.agent_request(channel_num, self).await?; if response { self.request_success() } else { self.request_failure() } - Ok((handler, self)) + Ok(()) } - b"exec" => { - let req = r.read_string().map_err(crate::Error::from)?; + "exec" => { + let req = map_err!(Bytes::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Exec { - want_reply: true, - command: req.into(), - }); + let _ = chan + .send(ChannelMsg::Exec { + want_reply: true, + command: req.to_vec(), + }) + .await; } debug!("handler.exec_request {:?}", channel_num); - handler.exec_request(channel_num, req, self).await + handler.exec_request(channel_num, &req, self).await } - b"subsystem" => { - let name = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + "subsystem" => { + let name = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestSubsystem { - want_reply: true, - name: name.into(), - }); + let _ = chan + .send(ChannelMsg::RequestSubsystem { + want_reply: true, + name: name.clone(), + }) + .await; } debug!("handler.subsystem_request {:?}", channel_num); - handler.subsystem_request(channel_num, name, self).await + handler.subsystem_request(channel_num, &name, self).await } - b"window-change" => { - let col_width = r.read_u32().map_err(crate::Error::from)?; - let row_height = r.read_u32().map_err(crate::Error::from)?; - let pix_width = r.read_u32().map_err(crate::Error::from)?; - let pix_height = r.read_u32().map_err(crate::Error::from)?; + "window-change" => { + let col_width = map_err!(u32::decode(r))?; + let row_height = map_err!(u32::decode(r))?; + let pix_width = map_err!(u32::decode(r))?; + let pix_height = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::WindowChange { - col_width, - row_height, - pix_width, - pix_height, - }); + let _ = chan + .send(ChannelMsg::WindowChange { + col_width, + row_height, + pix_width, + pix_height, + }) + .await; } debug!("handler.window_change {:?}", channel_num); @@ -920,69 +918,94 @@ impl Session { ) .await } - b"signal" => { - let signal = Sig::from_name(r.read_string().map_err(crate::Error::from)?)?; + "signal" => { + let signal = Sig::from_name(&map_err!(String::decode(r))?); if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::Signal { signal: signal.clone(), }) + .await .unwrap_or(()) } debug!("handler.signal {:?} {:?}", channel_num, signal); handler.signal(channel_num, signal, self).await } x => { - warn!("unknown channel request {}", String::from_utf8_lossy(x)); - self.channel_failure(channel_num); - Ok((handler, self)) + warn!("unknown channel request {x}"); + self.channel_failure(channel_num)?; + Ok(()) } } } - Some(&msg::GLOBAL_REQUEST) => { - let mut r = buf.reader(1); - let req_type = r.read_string().map_err(crate::Error::from)?; - self.common.wants_reply = r.read_byte().map_err(crate::Error::from)? != 0; - match req_type { - b"tcpip-forward" => { - let address = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let port = r.read_u32().map_err(crate::Error::from)?; + msg::GLOBAL_REQUEST => { + let req_type = map_err!(String::decode(r))?; + self.common.wants_reply = map_err!(u8::decode(r))? != 0; + match req_type.as_str() { + "tcpip-forward" => { + let address = map_err!(String::decode(r))?; + let port = map_err!(u32::decode(r))?; debug!("handler.tcpip_forward {:?} {:?}", address, port); let mut returned_port = port; - let (h, result, mut s) = handler - .tcpip_forward(address, &mut returned_port, self) + let result = handler + .tcpip_forward(&address, &mut returned_port, self) .await?; - if let Some(ref mut enc) = s.common.encrypted { + if let Some(ref mut enc) = self.common.encrypted { if result { push_packet!(enc.write, { enc.write.push(msg::REQUEST_SUCCESS); - if s.common.wants_reply && port == 0 && returned_port != 0 { - enc.write.push_u32_be(returned_port); + if self.common.wants_reply && port == 0 && returned_port != 0 { + map_err!(returned_port.encode(&mut enc.write))?; } }) } else { push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - Ok((h, s)) + Ok(()) } - b"cancel-tcpip-forward" => { - let address = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let port = r.read_u32().map_err(crate::Error::from)?; + "cancel-tcpip-forward" => { + let address = map_err!(String::decode(r))?; + let port = map_err!(u32::decode(r))?; debug!("handler.cancel_tcpip_forward {:?} {:?}", address, port); - let (h, result, mut s) = - handler.cancel_tcpip_forward(address, port, self).await?; - if let Some(ref mut enc) = s.common.encrypted { + let result = handler.cancel_tcpip_forward(&address, port, self).await?; + if let Some(ref mut enc) = self.common.encrypted { + if result { + push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) + } else { + push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) + } + } + Ok(()) + } + "streamlocal-forward@openssh.com" => { + let server_socket_path = map_err!(String::decode(r))?; + debug!("handler.streamlocal_forward {:?}", server_socket_path); + let result = handler + .streamlocal_forward(&server_socket_path, self) + .await?; + if let Some(ref mut enc) = self.common.encrypted { if result { push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) } else { push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - Ok((h, s)) + Ok(()) + } + "cancel-streamlocal-forward@openssh.com" => { + let socket_path = map_err!(String::decode(r))?; + debug!("handler.cancel_streamlocal_forward {:?}", socket_path); + let result = handler + .cancel_streamlocal_forward(&socket_path, self) + .await?; + if let Some(ref mut enc) = self.common.encrypted { + if result { + push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) + } else { + push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) + } + } + Ok(()) } _ => { if let Some(ref mut enc) = self.common.encrypted { @@ -990,23 +1013,17 @@ impl Session { enc.write.push(msg::REQUEST_FAILURE); }); } - Ok((handler, self)) + Ok(()) } } } - Some(&msg::CHANNEL_OPEN_FAILURE) => { + msg::CHANNEL_OPEN_FAILURE => { debug!("channel_open_failure"); - let mut buf_pos = buf.reader(1); - let channel_num = ChannelId(buf_pos.read_u32().map_err(crate::Error::from)?); - let reason = - ChannelOpenFailure::from_u32(buf_pos.read_u32().map_err(crate::Error::from)?) - .unwrap_or(ChannelOpenFailure::Unknown); - let description = - std::str::from_utf8(buf_pos.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let language_tag = - std::str::from_utf8(buf_pos.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(r))?; + let reason = ChannelOpenFailure::from_u32(map_err!(u32::decode(r))?) + .unwrap_or(ChannelOpenFailure::Unknown); + let description = map_err!(String::decode(r))?; + let language_tag = map_err!(String::decode(r))?; trace!("Channel open failure description: {description}"); trace!("Channel open failure language tag: {language_tag}"); @@ -1018,25 +1035,73 @@ impl Session { if let Some(channel_sender) = self.channels.remove(&channel_num) { channel_sender .send(ChannelMsg::OpenFailure(reason)) + .await .map_err(|_| crate::Error::SendError)?; } - Ok((handler, self)) + Ok(()) + } + msg::REQUEST_SUCCESS => { + trace!("Global Request Success"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let result = if r.is_finished() { + // If a specific port was requested, the reply has no data + Some(0) + } else { + match u32::decode(r) { + Ok(port) => Some(port), + Err(e) => { + error!("Error parsing port for TcpIpForward request: {e:?}"); + None + } + } + }; + let _ = return_channel.send(result); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(true); + } + _ => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) + } + msg::REQUEST_FAILURE => { + trace!("global request failure"); + match self.open_global_requests.pop_front() { + Some(GlobalRequestResponse::Keepalive) => { + // ignore keepalives + } + Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { + let _ = return_channel.send(None); + } + Some(GlobalRequestResponse::CancelTcpIpForward(return_channel)) => { + let _ = return_channel.send(false); + } + _ => { + error!("Received global request failure for unknown request!") + } + } + Ok(()) } m => { debug!("unknown message received: {:?}", m); - Ok((handler, self)) + Ok(()) } } } - async fn server_handle_channel_open( - mut self, - handler: H, - buf: &[u8], - ) -> Result<(H, bool, Self), H::Error> { - let mut r = buf.reader(1); - let msg = OpenChannelMessage::parse(&mut r)?; + async fn server_handle_channel_open( + &mut self, + handler: &mut H, + r: &mut R, + ) -> Result { + let msg = OpenChannelMessage::parse(r)?; let sender_channel = if let Some(ref mut enc) = self.common.encrypted { enc.new_channel_id() @@ -1056,23 +1121,24 @@ impl Session { confirmed: true, wants_reply: false, pending_data: std::collections::VecDeque::new(), + pending_eof: false, + pending_close: false, }; - let (sender, receiver) = unbounded_channel(); - let channel = Channel { - id: sender_channel, - sender: self.sender.sender.clone(), - receiver, - max_packet_size: channel_params.recipient_maximum_packet_size, - window_size: channel_params.recipient_window_size, - }; + let (channel, reference) = Channel::new( + sender_channel, + self.sender.sender.clone(), + channel_params.recipient_maximum_packet_size, + channel_params.recipient_window_size, + self.common.config.channel_buffer_size, + ); match &msg.typ { ChannelType::Session => { let mut result = handler.channel_open_session(channel, self).await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1083,9 +1149,9 @@ impl Session { let mut result = handler .channel_open_x11(channel, originator_address, *originator_port, self) .await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1100,9 +1166,9 @@ impl Session { self, ) .await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1117,28 +1183,48 @@ impl Session { self, ) .await; - if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); - s.finalize_channel_open(&msg, channel_params, *allowed); + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } + ChannelType::DirectStreamLocal(d) => { + let mut result = handler + .channel_open_direct_streamlocal(channel, &d.socket_path, self) + .await; + if let Ok(allowed) = &mut result { + self.channels.insert(sender_channel, reference); + self.finalize_channel_open(&msg, channel_params, *allowed)?; + } + result + } + ChannelType::ForwardedStreamLocal(_) => { + if let Some(ref mut enc) = self.common.encrypted { + msg.fail( + &mut enc.write, + msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, + b"Unsupported channel type", + )?; + } + Ok(false) + } ChannelType::AgentForward => { if let Some(ref mut enc) = self.common.encrypted { msg.fail( &mut enc.write, msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, b"Unsupported channel type", - ); + )?; } - Ok((handler, false, self)) + Ok(false) } ChannelType::Unknown { typ } => { - debug!("unknown channel type: {}", String::from_utf8_lossy(typ)); + debug!("unknown channel type: {typ}"); if let Some(ref mut enc) = self.common.encrypted { - msg.unknown_type(&mut enc.write); + msg.unknown_type(&mut enc.write)?; } - Ok((handler, false, self)) + Ok(false) } } } @@ -1148,7 +1234,7 @@ impl Session { open: &OpenChannelMessage, channel: ChannelParams, allowed: bool, - ) { + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if allowed { open.confirm( @@ -1156,15 +1242,16 @@ impl Session { channel.sender_channel.0, channel.sender_window_size, channel.sender_maximum_packet_size, - ); + )?; enc.channels.insert(channel.sender_channel, channel); } else { open.fail( &mut enc.write, SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, b"Rejected", - ); + )?; } } + Ok(()) } } diff --git a/russh/src/server/kex.rs b/russh/src/server/kex.rs index 07a83ecc..4116ef44 100644 --- a/russh/src/server/kex.rs +++ b/russh/src/server/kex.rs @@ -1,135 +1,366 @@ +use core::fmt; use std::cell::RefCell; -use russh_keys::encoding::{Encoding, Reader}; +use client::GexParams; use log::debug; +use num_bigint::BigUint; +use ssh_encoding::Encode; +use ssh_key::Algorithm; use super::*; -use crate::cipher::SealingKey; -use crate::kex::KEXES; -use crate::key::PubKey; -use crate::negotiation::Select; +use crate::helpers::sign_with_hash_alg; +use crate::kex::dh::biguint_to_mpint; +use crate::kex::{KexAlgorithm, KexAlgorithmImplementor, KexCause, KEXES}; +use crate::keys::key::PrivateKeyWithHashAlg; +use crate::negotiation::{is_key_compatible_with_algo, Names, Select}; use crate::{msg, negotiation}; thread_local! { static HASH_BUF: RefCell = RefCell::new(CryptoVec::new()); } -impl KexInit { - pub fn server_parse( - mut self, - config: &Config, - cipher: &mut dyn SealingKey, - buf: &[u8], - write_buffer: &mut SSHBuffer, - ) -> Result { - if buf.first() == Some(&msg::KEXINIT) { - let algo = { - // read algorithms from packet. - self.exchange.client_kex_init.extend(buf); - super::negotiation::Server::read_kex(buf, &config.preferred)? - }; - if !self.sent { - self.server_write(config, cipher, write_buffer)? +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum ServerKexState { + Created, + WaitingForGexRequest { + names: Names, + kex: KexAlgorithm, + }, + WaitingForDhInit { + // both KexInit and DH init sent + names: Names, + kex: KexAlgorithm, + }, + WaitingForNewKeys { + newkeys: NewKeys, + }, +} + +pub(crate) struct ServerKex { + exchange: Exchange, + cause: KexCause, + state: ServerKexState, + config: Arc, +} + +impl Debug for ServerKex { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut s = f.debug_struct("ClientKex"); + s.field("cause", &self.cause); + match self.state { + ServerKexState::Created => { + s.field("state", &"created"); } - let mut key = 0; - #[allow(clippy::indexing_slicing)] // length checked - while key < config.keys.len() && config.keys[key].name() != algo.key.as_ref() { - key += 1 + ServerKexState::WaitingForGexRequest { .. } => { + s.field("state", &"waiting for GEX request"); } - let next_kex = if key < config.keys.len() { - Kex::Dh(KexDh { - exchange: self.exchange, - key, - names: algo, - session_id: self.session_id, - }) - } else { - return Err(Error::UnknownKey); - }; + ServerKexState::WaitingForDhInit { .. } => { + s.field("state", &"waiting for DH reply"); + } + ServerKexState::WaitingForNewKeys { .. } => { + s.field("state", &"waiting for NEWKEYS"); + } + } + s.finish() + } +} - Ok(next_kex) - } else { - Ok(Kex::Init(self)) +impl ServerKex { + pub fn new( + config: Arc, + client_sshid: &[u8], + server_sshid: &SshId, + cause: KexCause, + ) -> Self { + let exchange = Exchange::new(client_sshid, server_sshid.as_kex_hash_bytes()); + Self { + config, + exchange, + cause, + state: ServerKexState::Created, } } - pub fn server_write( - &mut self, - config: &Config, - cipher: &mut dyn SealingKey, - write_buffer: &mut SSHBuffer, - ) -> Result<(), Error> { - self.exchange.server_kex_init.clear(); - negotiation::write_kex(&config.preferred, &mut self.exchange.server_kex_init, true)?; - debug!("server kex init: {:?}", &self.exchange.server_kex_init[..]); - self.sent = true; - cipher.write(&self.exchange.server_kex_init, write_buffer); + pub fn kexinit(&mut self, output: &mut PacketWriter) -> Result<(), Error> { + self.exchange.server_kex_init = + negotiation::write_kex(&self.config.preferred, output, Some(self.config.as_ref()))?; + Ok(()) } -} -impl KexDh { - pub fn parse( + pub async fn step( mut self, - config: &Config, - cipher: &mut dyn SealingKey, - buf: &[u8], - write_buffer: &mut SSHBuffer, - ) -> Result { - if self.names.ignore_guessed { - // If we need to ignore this packet. - self.names.ignore_guessed = false; - Ok(Kex::Dh(self)) - } else { - // Else, process it. - assert!(buf.first() == Some(&msg::KEX_ECDH_INIT)); - let mut r = buf.reader(1); - self.exchange.client_ephemeral.extend(r.read_string()?); - - let mut kex = KEXES.get(&self.names.kex).ok_or(Error::UnknownAlgo)?.make(); - - kex.server_dh(&mut self.exchange, buf)?; - - // Then, we fill the write buffer right away, so that we - // can output it immediately when the time comes. - let kexdhdone = KexDhDone { - exchange: self.exchange, - kex, - key: self.key, - names: self.names, - session_id: self.session_id, - }; - #[allow(clippy::indexing_slicing)] // key index checked - let hash: Result<_, Error> = HASH_BUF.with(|buffer| { - let mut buffer = buffer.borrow_mut(); - buffer.clear(); - debug!("server kexdhdone.exchange = {:?}", kexdhdone.exchange); - - let mut pubkey_vec = CryptoVec::new(); - config.keys[kexdhdone.key].push_to(&mut pubkey_vec); - - let hash = kexdhdone.kex.compute_exchange_hash( - &pubkey_vec, - &kexdhdone.exchange, - &mut buffer, - )?; - debug!("exchange hash: {:?}", hash); - buffer.clear(); - buffer.push(msg::KEX_ECDH_REPLY); - config.keys[kexdhdone.key].push_to(&mut buffer); - // Server ephemeral - buffer.extend_ssh_string(&kexdhdone.exchange.server_ephemeral); + input: Option<&mut IncomingSshPacket>, + output: &mut PacketWriter, + handler: &mut H, + ) -> Result, H::Error> { + match self.state { + ServerKexState::Created => { + let Some(input) = input else { + return Err(Error::KexInit)?; + }; + if input.buffer.first() != Some(&msg::KEXINIT) { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::KexInit)?; + } + + let names = { + self.exchange.client_kex_init.extend(&input.buffer); + negotiation::Server::read_kex( + &input.buffer, + &self.config.preferred, + Some(&self.config.keys), + )? + }; + debug!("negotiated: {names:?}"); + + // seqno has already been incremented after read() + if !self.cause.is_rekey() && self.cause.is_strict_kex(&names) && input.seqn.0 != 1 { + return Err(strict_kex_violation( + msg::KEXINIT, + input.seqn.0 as usize - 1, + ))?; + } + + let kex = KEXES.get(&names.kex).ok_or(Error::UnknownAlgo)?.make(); + + if kex.skip_exchange() { + let newkeys = compute_keys( + CryptoVec::new(), + kex, + names.clone(), + self.exchange.clone(), + self.cause.session_id(), + )?; + + output.packet(|w| { + msg::NEWKEYS.encode(w)?; + Ok(()) + })?; + + return Ok(KexProgress::Done { + newkeys, + server_host_key: None, + }); + } + + if kex.is_dh_gex() { + self.state = ServerKexState::WaitingForGexRequest { names, kex }; + } else { + self.state = ServerKexState::WaitingForDhInit { names, kex }; + } + + Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn: false, + }) + } + ServerKexState::WaitingForGexRequest { names, mut kex } => { + let Some(input) = input else { + return Err(Error::KexInit)?; + }; + if input.buffer.first() != Some(&msg::KEX_DH_GEX_REQUEST) { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::KexInit)?; + } + + #[allow(clippy::indexing_slicing)] // length checked + let gex_params = GexParams::decode(&mut &input.buffer[1..])?; + debug!("client requests a gex group: {:?}", gex_params); + + let Some(dh_group) = handler.lookup_dh_gex_group(&gex_params).await? else { + debug!("server::Handler impl did not find a matching DH group (is lookup_dh_gex_group implemented?)"); + return Err(Error::Kex)?; + }; + + let prime = biguint_to_mpint(&BigUint::from_bytes_be(&dh_group.prime)); + let generator = biguint_to_mpint(&BigUint::from_bytes_be(&dh_group.generator)); + + self.exchange.gex = Some((gex_params, dh_group.clone())); + kex.dh_gex_set_group(dh_group)?; + + output.packet(|w| { + msg::KEX_DH_GEX_GROUP.encode(w)?; + prime.encode(w)?; + generator.encode(w)?; + Ok(()) + })?; + + self.state = ServerKexState::WaitingForDhInit { names, kex }; + + Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn: false, + }) + } + ServerKexState::WaitingForDhInit { mut names, mut kex } => { + let Some(input) = input else { + return Err(Error::KexInit)?; + }; + + if names.ignore_guessed { + // Ignore the next packet if (1) it follows and (2) it's not the correct guess. + debug!("ignoring guessed kex"); + names.ignore_guessed = false; + self.state = ServerKexState::WaitingForDhInit { names, kex }; + return Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn: false, + }); + } + + if input.buffer.first() + != Some(match kex.is_dh_gex() { + true => &msg::KEX_DH_GEX_INIT, + false => &msg::KEX_ECDH_INIT, + }) + { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::KexInit)?; + } + + #[allow(clippy::indexing_slicing)] // length checked + let mut r = &input.buffer[1..]; + + self.exchange + .client_ephemeral + .extend(&Bytes::decode(&mut r).map_err(Into::into)?); + + let exchange = &mut self.exchange; + kex.server_dh(exchange, &input.buffer)?; + + let Some(matching_key_index) = self + .config + .keys + .iter() + .position(|key| is_key_compatible_with_algo(key, &names.key)) + else { + debug!("we don't have a host key of type {:?}", names.key); + return Err(Error::UnknownKey.into()); + }; + + // Look up the key we'll be using to sign the exchange hash + #[allow(clippy::indexing_slicing)] // key index checked + let key = &self.config.keys[matching_key_index]; + let signature_hash_alg = match &names.key { + Algorithm::Rsa { hash } => *hash, + _ => None, + }; + + let hash = HASH_BUF.with(|buffer| { + let mut buffer = buffer.borrow_mut(); + buffer.clear(); + + let mut pubkey_vec = CryptoVec::new(); + key.public_key().to_bytes()?.encode(&mut pubkey_vec)?; + + let hash = kex.compute_exchange_hash(&pubkey_vec, exchange, &mut buffer)?; + + Ok::<_, Error>(hash) + })?; + // Hash signature - debug!("signing with key {:?}", kexdhdone.key); - debug!("hash: {:?}", hash); - debug!("key: {:?}", config.keys[kexdhdone.key]); - config.keys[kexdhdone.key].add_signature(&mut buffer, &hash)?; - cipher.write(&buffer, write_buffer); - cipher.write(&[msg::NEWKEYS], write_buffer); - Ok(hash) - }); - - Ok(Kex::Keys(kexdhdone.compute_keys(hash?, true)?)) + debug!("signing with key {:?}", key); + let signature = sign_with_hash_alg( + &PrivateKeyWithHashAlg::new(Arc::new(key.clone()), signature_hash_alg), + &hash, + ) + .map_err(Into::into)?; + + output.packet(|w| { + match kex.is_dh_gex() { + true => &msg::KEX_DH_GEX_REPLY, + false => &msg::KEX_ECDH_REPLY, + } + .encode(w)?; + key.public_key().to_bytes()?.encode(w)?; + exchange.server_ephemeral.encode(w)?; + signature.encode(w)?; + Ok(()) + })?; + + output.packet(|w| { + msg::NEWKEYS.encode(w)?; + Ok(()) + })?; + + let newkeys = compute_keys( + hash, + kex, + names.clone(), + self.exchange.clone(), + self.cause.session_id(), + )?; + + let reset_seqn = self.cause.is_strict_kex(&newkeys.names); + + self.state = ServerKexState::WaitingForNewKeys { newkeys }; + + Ok(KexProgress::NeedsReply { + kex: self, + reset_seqn, + }) + } + ServerKexState::WaitingForNewKeys { newkeys } => { + let Some(input) = input else { + return Err(Error::KexInit.into()); + }; + + if input.buffer.first() != Some(&msg::NEWKEYS) { + error!( + "Unexpected kex message at this stage: {:?}", + input.buffer.first() + ); + return Err(Error::Kex.into()); + } + + debug!("new keys received"); + Ok(KexProgress::Done { + newkeys, + server_host_key: None, + }) + } } } } + +fn compute_keys( + hash: CryptoVec, + kex: KexAlgorithm, + names: Names, + exchange: Exchange, + session_id: Option<&CryptoVec>, +) -> Result { + let session_id = if let Some(session_id) = session_id { + session_id + } else { + &hash + }; + // Now computing keys. + let c = kex.compute_keys( + session_id, + &hash, + names.cipher, + names.client_mac, + names.server_mac, + true, + )?; + Ok(NewKeys { + exchange, + names, + kex, + key: 0, + cipher: c, + session_id: session_id.clone(), + }) +} diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index 7f07cae8..b0b4f155 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -16,136 +16,56 @@ //! # Writing servers //! //! There are two ways of accepting connections: -//! * implement the [Server](server::Server) trait and let [run](server::run) handle everything +//! * implement the [Server](server::Server) trait and let [run_on_socket](server::Server::run_on_socket)/[run_on_address](server::Server::run_on_address) handle everything //! * accept connections yourself and pass them to [run_stream](server::run_stream) //! //! In both cases, you'll first need to implement the [Handler](server::Handler) trait - //! this is where you'll handle various events. //! -//! Here is an example server, which forwards input from each client -//! to all other clients: +//! Check out the following examples: //! -//! ``` -//! use async_trait::async_trait; -//! use std::sync::{Mutex, Arc}; -//! use russh::*; -//! use russh::server::{Auth, Session, Msg}; -//! use russh_keys::*; -//! use std::collections::HashMap; -//! use futures::Future; -//! -//! #[tokio::main] -//! async fn main() { -//! let client_key = russh_keys::key::KeyPair::generate_ed25519().unwrap(); -//! let client_pubkey = Arc::new(client_key.clone_public_key().unwrap()); -//! let mut config = russh::server::Config::default(); -//! config.inactivity_timeout = Some(std::time::Duration::from_secs(3)); -//! config.auth_rejection_time = std::time::Duration::from_secs(3); -//! config.keys.push(russh_keys::key::KeyPair::generate_ed25519().unwrap()); -//! let config = Arc::new(config); -//! let sh = Server{ -//! client_pubkey, -//! clients: Arc::new(Mutex::new(HashMap::new())), -//! id: 0 -//! }; -//! tokio::time::timeout( -//! std::time::Duration::from_secs(1), -//! russh::server::run(config, ("0.0.0.0", 2222), sh) -//! ).await.unwrap_or(Ok(())); -//! } -//! -//! #[derive(Clone)] -//! struct Server { -//! client_pubkey: Arc, -//! clients: Arc>>>, -//! id: usize, -//! } -//! -//! impl server::Server for Server { -//! type Handler = Self; -//! fn new_client(&mut self, _: Option) -> Self { -//! let s = self.clone(); -//! self.id += 1; -//! s -//! } -//! } -//! -//! #[async_trait] -//! impl server::Handler for Server { -//! type Error = anyhow::Error; -//! -//! async fn channel_open_session(self, channel: Channel, session: Session) -> Result<(Self, bool, Session), Self::Error> { -//! { -//! let mut clients = self.clients.lock().unwrap(); -//! clients.insert((self.id, channel.id()), channel); -//! } -//! Ok((self, true, session)) -//! } -//! async fn auth_publickey(self, _: &str, _: &key::PublicKey) -> Result<(Self, Auth), Self::Error> { -//! Ok((self, server::Auth::Accept)) -//! } -//! async fn data(self, channel: ChannelId, data: &[u8], mut session: Session) -> Result<(Self, Session), Self::Error> { -//! { -//! let mut clients = self.clients.lock().unwrap(); -//! for ((id, _channel_id), ref mut channel) in clients.iter_mut() { -//! channel.data(data); -//! } -//! } -//! Ok((self, session)) -//! } -//! } -//! ``` -//! -//! Note the call to `session.handle()`, which allows to keep a handle -//! to a client outside the event loop. This feature is internally -//! implemented using `futures::sync::mpsc` channels. -//! -//! Note that this is just a toy server. In particular: -//! -//! - It doesn't handle errors when `s.data` returns an error, i.e. when the -//! client has disappeared -//! -//! - Each new connection increments the `id` field. Even though we -//! would need a lot of connections per second for a very long time to -//! saturate it, there are probably better ways to handle this to -//! avoid collisions. +//! * [Server that forwards your input to all connected clients](https://github.com/warp-tech/russh/blob/main/russh/examples/echoserver.rs) +//! * [Server handing channel processing off to a library (here, `russh-sftp`)](https://github.com/warp-tech/russh/blob/main/russh/examples/sftp_server.rs) +//! * Serving `ratatui` based TUI app to clients: [per-client](https://github.com/warp-tech/russh/blob/main/russh/examples/ratatui_app.rs), [shared](https://github.com/warp-tech/russh/blob/main/russh/examples/ratatui_shared_app.rs) use std; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::error; -use async_trait::async_trait; +use bytes::Bytes; +use client::GexParams; use futures::future::Future; -use russh_keys::key; +use log::{debug, error, info, warn}; +use msg::{is_kex_msg, validate_client_msg_strict_kex}; +use russh_util::runtime::JoinHandle; +use russh_util::time::Instant; +use ssh_key::{Certificate, PrivateKey}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::pin; -use tokio::task::JoinHandle; -use crate::cipher::{clear, CipherPair, OpeningKey}; +use crate::cipher::{clear, OpeningKey}; +use crate::kex::dh::groups::{DhGroup, BUILTIN_SAFE_DH_GROUPS, DH_GROUP14}; +use crate::kex::{KexProgress, SessionKexState}; use crate::session::*; use crate::ssh_read::*; use crate::sshbuffer::*; -use crate::*; +use crate::{map_err, *}; mod kex; mod session; -pub use self::kex::*; pub use self::session::*; mod encrypted; -#[derive(Debug)] /// Configuration of a server. pub struct Config { /// The server ID string sent at the beginning of the protocol. pub server_id: SshId, /// Authentication methods proposed to the client. pub methods: auth::MethodSet, - /// The authentication banner, usually a warning message shown to the client. - pub auth_banner: Option<&'static str>, /// Authentication rejections must happen in constant time for /// security reasons. Russh does not handle this by default. pub auth_rejection_time: std::time::Duration, @@ -153,13 +73,15 @@ pub struct Config { /// OpenSSH clients will send an initial "none" auth to probe for authentication methods. pub auth_rejection_time_initial: Option, /// The server's keys. The first key pair in the client's preference order will be chosen. - pub keys: Vec, + pub keys: Vec, /// The bytes and time limits before key re-exchange. pub limits: Limits, /// The initial size of a channel (used for flow control). pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) + pub channel_buffer_size: usize, /// Internal event buffer size pub event_buffer_size: usize, /// Lists of preferred algorithms. @@ -168,6 +90,12 @@ pub struct Config { pub max_auth_attempts: usize, /// Time after which the connection is garbage-collected. pub inactivity_timeout: Option, + /// If nothing is received from the client for this amount of time, send a keepalive message. + pub keepalive_interval: Option, + /// If this many keepalives have been sent without reply, close the connection. + pub keepalive_max: usize, + /// If active, invoke `set_nodelay(true)` on client sockets; disabled by default (i.e. Nagle's algorithm is active). + pub nodelay: bool, } impl Default for Config { @@ -179,39 +107,59 @@ impl Default for Config { env!("CARGO_PKG_VERSION") )), methods: auth::MethodSet::all(), - auth_banner: None, auth_rejection_time: std::time::Duration::from_secs(1), auth_rejection_time_initial: None, keys: Vec::new(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, event_buffer_size: 10, limits: Limits::default(), preferred: Default::default(), max_auth_attempts: 10, inactivity_timeout: Some(std::time::Duration::from_secs(600)), + keepalive_interval: None, + keepalive_max: 3, + nodelay: false, } } } +impl Debug for Config { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // display everything except the private keys + f.debug_struct("Config") + .field("server_id", &self.server_id) + .field("methods", &self.methods) + .field("auth_rejection_time", &self.auth_rejection_time) + .field( + "auth_rejection_time_initial", + &self.auth_rejection_time_initial, + ) + .field("keys", &"***") + .field("window_size", &self.window_size) + .field("maximum_packet_size", &self.maximum_packet_size) + .field("channel_buffer_size", &self.channel_buffer_size) + .field("event_buffer_size", &self.event_buffer_size) + .field("limits", &self.limits) + .field("preferred", &self.preferred) + .field("max_auth_attempts", &self.max_auth_attempts) + .field("inactivity_timeout", &self.inactivity_timeout) + .field("keepalive_interval", &self.keepalive_interval) + .field("keepalive_max", &self.keepalive_max) + .finish() + } +} + /// A client's response in a challenge-response authentication. /// /// You should iterate it to get `&[u8]` response slices. -#[derive(Debug)] -pub struct Response<'a> { - pos: russh_keys::encoding::Position<'a>, - n: u32, -} +pub struct Response<'a>(&'a mut (dyn Iterator> + Send)); -impl<'a> Iterator for Response<'a> { - type Item = &'a [u8]; +impl Iterator for Response<'_> { + type Item = Bytes; fn next(&mut self) -> Option { - if self.n == 0 { - None - } else { - self.n -= 1; - self.pos.read_string().ok() - } + self.0.next().flatten() } } @@ -222,6 +170,7 @@ pub enum Auth { /// Reject the authentication request. Reject { proceed_with_methods: Option, + partial_success: bool, }, /// Accept the authentication request. Accept, @@ -243,10 +192,20 @@ pub enum Auth { }, } +impl Auth { + pub fn reject() -> Self { + Auth::Reject { + proceed_with_methods: None, + partial_success: false, + } + } +} + /// Server handler. Each client will have their own handler. /// -/// Note: this is an `async_trait`. Click `[source]` on the right to see actual async function definitions. -#[async_trait] +/// Note: this is an async trait. The trait functions return `impl Future`, +/// and you can simply define them as `async fn` instead. +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] pub trait Handler: Sized { type Error: From + Send; @@ -254,13 +213,8 @@ pub trait Handler: Sized { /// sure rejection happens in time `config.auth_rejection_time`, /// except if this method takes more than that. #[allow(unused_variables)] - async fn auth_none(self, user: &str) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + fn auth_none(&mut self, user: &str) -> impl Future> + Send { + async { Ok(Auth::reject()) } } /// Check authentication using the "password" method. Russh @@ -268,13 +222,12 @@ pub trait Handler: Sized { /// `config.auth_rejection_time`, except if this method takes more /// than that. #[allow(unused_variables)] - async fn auth_password(self, user: &str, password: &str) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + fn auth_password( + &mut self, + user: &str, + password: &str, + ) -> impl Future> + Send { + async { Ok(Auth::reject()) } } /// Check authentication using the "publickey" method. This method @@ -285,17 +238,42 @@ pub trait Handler: Sized { /// `config.auth_rejection_time`, except if this method takes more /// time than that. #[allow(unused_variables)] - async fn auth_publickey( - self, + fn auth_publickey_offered( + &mut self, user: &str, - public_key: &key::PublicKey, - ) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + public_key: &ssh_key::PublicKey, + ) -> impl Future> + Send { + async { Ok(Auth::Accept) } + } + + /// Check authentication using the "publickey" method. This method + /// is called after the signature has been verified and key + /// ownership has been confirmed. + /// Russh guarantees that rejection happens in constant time + /// `config.auth_rejection_time`, except if this method takes more + /// time than that. + #[allow(unused_variables)] + fn auth_publickey( + &mut self, + user: &str, + public_key: &ssh_key::PublicKey, + ) -> impl Future> + Send { + async { Ok(Auth::reject()) } + } + + /// Check authentication using an OpenSSH certificate. This method + /// is called after the signature has been verified and key + /// ownership has been confirmed. + /// Russh guarantees that rejection happens in constant time + /// `config.auth_rejection_time`, except if this method takes more + /// time than that. + #[allow(unused_variables)] + fn auth_openssh_certificate( + &mut self, + user: &str, + certificate: &Certificate, + ) -> impl Future> + Send { + async { Ok(Auth::reject()) } } /// Check authentication using the "keyboard-interactive" @@ -303,124 +281,143 @@ pub trait Handler: Sized { /// `config.auth_rejection_time`, except if this method takes more /// than that. #[allow(unused_variables)] - async fn auth_keyboard_interactive( - self, + fn auth_keyboard_interactive<'a>( + &'a mut self, user: &str, submethods: &str, - response: Option>, - ) -> Result<(Self, Auth), Self::Error> { - Ok(( - self, - Auth::Reject { - proceed_with_methods: None, - }, - )) + response: Option>, + ) -> impl Future> + Send { + async { Ok(Auth::reject()) } } /// Called when authentication succeeds for a session. #[allow(unused_variables)] - async fn auth_succeeded(self, session: Session) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + fn auth_succeeded( + &mut self, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } + } + + /// Called when authentication starts but before it is successful. + /// Return value is an authentication banner, usually a warning message shown to the client. + #[allow(unused_variables)] + fn authentication_banner( + &mut self, + ) -> impl Future, Self::Error>> + Send { + async { Ok(None) } } /// Called when the client closes a channel. #[allow(unused_variables)] - async fn channel_close( - self, + fn channel_close( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the client sends EOF to a channel. #[allow(unused_variables)] - async fn channel_eof( - self, + fn channel_eof( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when a new session channel is created. /// Return value indicates whether the channel request should be granted. #[allow(unused_variables)] - async fn channel_open_session( - self, + fn channel_open_session( + &mut self, channel: Channel, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } } /// Called when a new X11 channel is created. /// Return value indicates whether the channel request should be granted. #[allow(unused_variables)] - async fn channel_open_x11( - self, + fn channel_open_x11( + &mut self, channel: Channel, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } } - /// Called when a new TCP/IP is created. + /// Called when a new direct TCP/IP ("local TCP forwarding") channel is opened. /// Return value indicates whether the channel request should be granted. #[allow(unused_variables)] - async fn channel_open_direct_tcpip( - self, + fn channel_open_direct_tcpip( + &mut self, channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } } - /// Called when a new forwarded connection comes in. + /// Called when a new remote forwarded TCP connection comes in. /// #[allow(unused_variables)] - async fn channel_open_forwarded_tcpip( - self, + fn channel_open_forwarded_tcpip( + &mut self, channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } + } + + /// Called when a new direct-streamlocal ("local UNIX socket forwarding") channel is created. + /// Return value indicates whether the channel request should be granted. + #[allow(unused_variables)] + fn channel_open_direct_streamlocal( + &mut self, + channel: Channel, + socket_path: &str, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } } /// Called when the client confirmed our request to open a /// channel. A channel can only be written to after receiving this /// message (this library panics otherwise). #[allow(unused_variables)] - async fn channel_open_confirmation( - self, + fn channel_open_confirmation( + &mut self, id: ChannelId, max_packet_size: u32, window_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when a data packet is received. A response can be /// written to the `response` argument. #[allow(unused_variables)] - async fn data( - self, + fn data( + &mut self, channel: ChannelId, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when an extended data packet is received. Code 1 means @@ -428,26 +425,26 @@ pub trait Handler: Sized { /// defined (see /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2)). #[allow(unused_variables)] - async fn extended_data( - self, + fn extended_data( + &mut self, channel: ChannelId, code: u32, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when the network window is adjusted, meaning that we /// can send more bytes. #[allow(unused_variables)] - async fn window_adjusted( - self, + fn window_adjusted( + &mut self, channel: ChannelId, new_size: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Called when this server adjusts the network window. Return the @@ -459,9 +456,30 @@ pub trait Handler: Sized { /// The client requests a pseudo-terminal with the given /// specifications. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn pty_request( + /// &mut self, + /// channel: ChannelId, + /// term: &str, + /// col_width: u32, + /// row_height: u32, + /// pix_width: u32, + /// pix_height: u32, + /// modes: &[(Pty, u32)], + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables, clippy::too_many_arguments)] - async fn pty_request( - self, + fn pty_request( + &mut self, channel: ChannelId, term: &str, col_width: u32, @@ -469,164 +487,406 @@ pub trait Handler: Sized { pix_width: u32, pix_height: u32, modes: &[(Pty, u32)], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client requests an X11 connection. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn x11_request( + /// &mut self, + /// channel: ChannelId, + /// single_connection: bool, + /// x11_auth_protocol: &str, + /// x11_auth_cookie: &str, + /// x11_screen_number: u32, + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn x11_request( - self, + fn x11_request( + &mut self, channel: ChannelId, single_connection: bool, x11_auth_protocol: &str, x11_auth_cookie: &str, x11_screen_number: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client wants to set the given environment variable. Check /// these carefully, as it is dangerous to allow any variable /// environment to be set. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn env_request( + /// &mut self, + /// channel: ChannelId, + /// variable_name: &str, + /// variable_value: &str, + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn env_request( - self, + fn env_request( + &mut self, channel: ChannelId, variable_name: &str, variable_value: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client requests a shell. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn shell_request( + /// &mut self, + /// channel: ChannelId, + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn shell_request( - self, + fn shell_request( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client sends a command to execute, to be passed to a /// shell. Make sure to check the command before doing so. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn exec_request( + /// &mut self, + /// channel: ChannelId, + /// data: &[u8], + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn exec_request( - self, + fn exec_request( + &mut self, channel: ChannelId, data: &[u8], - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client asks to start the subsystem with the given name /// (such as sftp). + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn subsystem_request( + /// &mut self, + /// channel: ChannelId, + /// name: &str, + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn subsystem_request( - self, + fn subsystem_request( + &mut self, channel: ChannelId, name: &str, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client's pseudo-terminal window size has changed. + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn window_change_request( + /// &mut self, + /// channel: ChannelId, + /// col_width: u32, + /// row_height: u32, + /// pix_width: u32, + /// pix_height: u32, + /// session: &mut Session, + /// ) -> Result<(), Self::Error> { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn window_change_request( - self, + fn window_change_request( + &mut self, channel: ChannelId, col_width: u32, row_height: u32, pix_width: u32, pix_height: u32, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// The client requests OpenSSH agent forwarding + /// + /// **Note:** Success or failure should be communicated to the client by calling + /// `session.channel_success(channel)` or `session.channel_failure(channel)` respectively. For + /// instance: + /// + /// ```ignore + /// async fn agent_request( + /// &mut self, + /// channel: ChannelId, + /// session: &mut Session, + /// ) -> Result { + /// session.channel_success(channel); + /// Ok(()) + /// } + /// ``` #[allow(unused_variables)] - async fn agent_request( - self, + fn agent_request( + &mut self, channel: ChannelId, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } } /// The client is sending a signal (usually to pass to the /// currently running process). #[allow(unused_variables)] - async fn signal( - self, + fn signal( + &mut self, channel: ChannelId, signal: Sig, - session: Session, - ) -> Result<(Self, Session), Self::Error> { - Ok((self, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } } /// Used for reverse-forwarding ports, see /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-7). /// If `port` is 0, you should set it to the allocated port number. #[allow(unused_variables)] - async fn tcpip_forward( - self, + fn tcpip_forward( + &mut self, address: &str, port: &mut u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } } /// Used to stop the reverse-forwarding of a port, see /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-7). #[allow(unused_variables)] - async fn cancel_tcpip_forward( - self, + fn cancel_tcpip_forward( + &mut self, address: &str, port: u32, - session: Session, - ) -> Result<(Self, bool, Session), Self::Error> { - Ok((self, false, session)) + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } + } + + #[allow(unused_variables)] + fn streamlocal_forward( + &mut self, + socket_path: &str, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } + } + + #[allow(unused_variables)] + fn cancel_streamlocal_forward( + &mut self, + socket_path: &str, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(false) } + } + + /// Override when enabling the `diffie-hellman-group-exchange-*` key exchange methods. + /// Should return a Diffie-Hellman group with a safe prime whose length is + /// between `gex_params.min_group_size` and `gex_params.max_group_size` and + /// (if possible) over and as close as possible to `gex_params.preferred_group_size`. + /// + /// OpenSSH uses a pre-generated database of safe primes stored in `/etc/ssh/moduli` + /// + /// The default implementation picks a group from a very short static list + /// of built-in standard groups and is not really taking advantage of the security + /// offered by these kex methods. + /// + /// See https://datatracker.ietf.org/doc/html/rfc4419#section-3 + #[allow(unused_variables)] + fn lookup_dh_gex_group( + &mut self, + gex_params: &GexParams, + ) -> impl Future, Self::Error>> + Send { + async { + let mut best_group = &DH_GROUP14; + + // Find _some_ matching group + for group in BUILTIN_SAFE_DH_GROUPS.iter() { + if group.bit_size() >= gex_params.min_group_size() + && group.bit_size() <= gex_params.max_group_size() + { + best_group = *group; + break; + } + } + + // Find _closest_ matching group + for group in BUILTIN_SAFE_DH_GROUPS.iter() { + if group.bit_size() > gex_params.preferred_group_size() { + best_group = *group; + break; + } + } + + Ok(Some(best_group.clone())) + } } } +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] /// Trait used to create new handlers when clients connect. pub trait Server { /// The type of handlers. - type Handler: Handler + Send; + type Handler: Handler + Send + 'static; /// Called when a new client connects. fn new_client(&mut self, peer_addr: Option) -> Self::Handler; -} + /// Called when an active connection fails. + fn handle_session_error(&mut self, _error: ::Error) {} -/// Run a server. -/// Create a new `Connection` from the server's configuration, a -/// stream and a [`Handler`](trait.Handler.html). -pub async fn run( - config: Arc, - addrs: A, - mut server: H, -) -> Result<(), std::io::Error> { - let socket = TcpListener::bind(addrs).await?; - if config.maximum_packet_size > 65535 { - error!( - "Maximum packet size ({:?}) should not larger than a TCP packet (65535)", - config.maximum_packet_size - ); + /// Run a server on a specified `tokio::net::TcpListener`. Useful when dropping + /// privileges immediately after socket binding, for example. + fn run_on_socket( + &mut self, + config: Arc, + socket: &TcpListener, + ) -> impl Future> + Send + where + Self: Send, + { + async move { + if config.maximum_packet_size > 65535 { + error!( + "Maximum packet size ({:?}) should not larger than a TCP packet (65535)", + config.maximum_packet_size + ); + } + + let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel(); + + loop { + tokio::select! { + accept_result = socket.accept() => { + match accept_result { + Ok((socket, _)) => { + let config = config.clone(); + let handler = self.new_client(socket.peer_addr().ok()); + let error_tx = error_tx.clone(); + + russh_util::runtime::spawn(async move { + if config.nodelay { + if let Err(e) = socket.set_nodelay(true) { + warn!("set_nodelay() failed: {e:?}"); + } + } + + let session = match run_stream(config, socket, handler).await { + Ok(s) => s, + Err(e) => { + debug!("Connection setup failed"); + let _ = error_tx.send(e); + return + } + }; + + match session.await { + Ok(_) => debug!("Connection closed"), + Err(e) => { + debug!("Connection closed with error"); + let _ = error_tx.send(e); + } + } + }); + } + Err(e) => { + return Err(e); + } + } + }, + + Some(error) = error_rx.recv() => { + self.handle_session_error(error); + } + } + } + } } - while let Ok((socket, _)) = socket.accept().await { - let config = config.clone(); - let server = server.new_client(socket.peer_addr().ok()); - tokio::spawn(run_stream(config, socket, server)); + + /// Run a server. + /// Create a new `Connection` from the server's configuration, a + /// stream and a [`Handler`](trait.Handler.html). + fn run_on_address( + &mut self, + config: Arc, + addrs: A, + ) -> impl Future> + Send + where + Self: Send, + { + async move { + let socket = TcpListener::bind(addrs).await?; + self.run_on_socket(config, &socket).await + } } - Ok(()) } use std::cell::RefCell; @@ -635,14 +895,6 @@ thread_local! { static B2: RefCell = RefCell::new(CryptoVec::new()); } -pub(crate) async fn timeout(delay: Option) { - if let Some(delay) = delay { - tokio::time::sleep(delay).await - } else { - futures::future::pending().await - }; -} - async fn start_reading( mut stream_read: R, mut buffer: SSHBuffer, @@ -655,7 +907,7 @@ async fn start_reading( /// An active server session returned by [run_stream]. /// -/// Implements [Future] and needs to be awaited to allow the session to run. +/// Implements [Future] and can be awaited to wait for the session to finish. pub struct RunningSession { handle: Handle, join: JoinHandle>, @@ -682,7 +934,7 @@ impl Future for RunningSession { } } -/// Run a single connection to completion. +/// Start a single connection in the background. pub async fn run_stream( config: Arc, mut stream: R, @@ -695,17 +947,18 @@ where // Writing SSH id. let mut write_buffer = SSHBuffer::new(); write_buffer.send_ssh_id(&config.as_ref().server_id); - stream - .write_all(&write_buffer.buffer[..]) - .await - .map_err(crate::Error::from)?; + map_err!(stream.write_all(&write_buffer.buffer[..]).await)?; // Reading SSH id and allocating a session. let mut stream = SshRead::new(stream); let (sender, receiver) = tokio::sync::mpsc::channel(config.event_buffer_size); + let handle = server::session::Handle { + sender, + channel_buffer_size: config.channel_buffer_size, + }; + let common = read_ssh_id(config, &mut stream).await?; - let handle = server::session::Handle { sender }; - let session = Session { + let mut session = Session { target_window_size: common.config.window_size, common, receiver, @@ -713,9 +966,13 @@ where pending_reads: Vec::new(), pending_len: 0, channels: HashMap::new(), + open_global_requests: VecDeque::new(), + kex: SessionKexState::Idle, }; - let join = tokio::spawn(session.run(stream, handler)); + session.begin_rekey()?; + + let join = russh_util::runtime::spawn(session.run(stream, handler)); Ok(RunningSession { handle, join }) } @@ -729,99 +986,121 @@ async fn read_ssh_id( } else { read.read_ssh_id().await? }; - let mut exchange = Exchange::new(); - exchange.client_id.extend(sshid); - // Preparing the response - exchange - .server_id - .extend(config.as_ref().server_id.as_kex_hash_bytes()); - let mut kexinit = KexInit { - exchange, - algo: None, - sent: false, - session_id: None, - }; - let mut cipher = CipherPair { - local_to_remote: Box::new(clear::Key), - remote_to_local: Box::new(clear::Key), - }; - let mut write_buffer = SSHBuffer::new(); - kexinit.server_write( - config.as_ref(), - &mut *cipher.local_to_remote, - &mut write_buffer, - )?; - Ok(CommonSession { - write_buffer, - kex: Some(Kex::Init(kexinit)), + + let session = CommonSession { + packet_writer: PacketWriter::clear(), + // kex: Some(Kex::Init(kexinit)), auth_user: String::new(), auth_method: None, // Client only. auth_attempts: 0, - cipher, + remote_to_local: Box::new(clear::Key), encrypted: None, config, wants_reply: false, disconnected: false, buffer: CryptoVec::new(), - }) + strict_kex: false, + alive_timeouts: 0, + received_data: false, + remote_sshid: sshid.into(), + }; + Ok(session) } async fn reply( - mut session: Session, - handler: H, - buf: &[u8], -) -> Result<(H, Session), H::Error> { - // Handle key exchange/re-exchange. - if session.common.encrypted.is_none() { - match session.common.kex.take() { - Some(Kex::Init(kexinit)) => { - if kexinit.algo.is_some() || buf.first() == Some(&msg::KEXINIT) { - session.common.kex = Some(kexinit.server_parse( - session.common.config.as_ref(), - &mut *session.common.cipher.local_to_remote, - buf, - &mut session.common.write_buffer, - )?); - return Ok((handler, session)); - } else { - // Else, i.e. if the other side has not started - // the key exchange, process its packets by simple - // not returning. - session.common.kex = Some(Kex::Init(kexinit)) + session: &mut Session, + handler: &mut H, + pkt: &mut IncomingSshPacket, +) -> Result<(), H::Error> { + if let Some(message_type) = pkt.buffer.first() { + debug!( + "< msg type {message_type:?}, seqn {:?}, len {}", + pkt.seqn.0, + pkt.buffer.len() + ); + if session.common.strict_kex && session.common.encrypted.is_none() { + let seqno = pkt.seqn.0 - 1; // was incremented after read() + validate_client_msg_strict_kex(*message_type, seqno as usize)?; + } + + if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) { + return Ok(()); + } + } + + if pkt.buffer.first() == Some(&msg::KEXINIT) && session.kex == SessionKexState::Idle { + // Not currently in a rekey but received KEXINIT + info!("Client has initiated re-key"); + session.begin_rekey()?; + // Kex will consume the packet right away + } + + let is_kex_msg = pkt.buffer.first().cloned().map(is_kex_msg).unwrap_or(false); + + if is_kex_msg { + if let SessionKexState::InProgress(kex) = session.kex.take() { + let progress = kex + .step(Some(pkt), &mut session.common.packet_writer, handler) + .await?; + + match progress { + KexProgress::NeedsReply { kex, reset_seqn } => { + debug!("kex impl continues: {kex:?}"); + session.kex = SessionKexState::InProgress(kex); + if reset_seqn { + debug!("kex impl requests seqno reset"); + session.common.reset_seqn(); + } } - } - Some(Kex::Dh(kexdh)) => { - session.common.kex = Some(kexdh.parse( - session.common.config.as_ref(), - &mut *session.common.cipher.local_to_remote, - buf, - &mut session.common.write_buffer, - )?); - return Ok((handler, session)); - } - Some(Kex::Keys(newkeys)) => { - if buf.first() != Some(&msg::NEWKEYS) { - return Err(Error::Kex.into()); + KexProgress::Done { newkeys, .. } => { + debug!("kex impl has completed"); + session.common.strict_kex = + session.common.strict_kex || newkeys.names.strict_kex; + + if let Some(ref mut enc) = session.common.encrypted { + // This is a rekey + enc.last_rekey = Instant::now(); + session.common.packet_writer.buffer().bytes = 0; + enc.flush_all_pending()?; + + let mut pending = std::mem::take(&mut session.pending_reads); + for p in pending.drain(..) { + session.process_packet(handler, &p).await?; + } + session.pending_reads = pending; + session.pending_len = 0; + session.common.newkeys(newkeys); + session.flush()?; + } else { + // This is the initial kex + + session.common.encrypted( + EncryptedState::WaitingAuthServiceRequest { + sent: false, + accepted: false, + }, + newkeys, + ); + + session.maybe_send_ext_info()?; + } + + session.kex = SessionKexState::Idle; + + if session.common.strict_kex { + pkt.seqn = Wrapping(0); + } + + debug!("kex done"); } - // Ok, NEWKEYS received, now encrypted. - session.common.encrypted( - EncryptedState::WaitingAuthServiceRequest { - sent: false, - accepted: false, - }, - newkeys, - ); - session.maybe_send_ext_info(); - return Ok((handler, session)); - } - Some(kex) => { - session.common.kex = Some(kex); - return Ok((handler, session)); } - None => {} + + session.flush()?; + + return Ok(()); } - Ok((handler, session)) - } else { - Ok(session.server_read_encrypted(handler, buf).await?) } + + // Handle key exchange/re-exchange. + session.server_read_encrypted(handler, pkt).await } diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 7cbcfeb3..4893a7e9 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -1,17 +1,23 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; +use std::io::ErrorKind; use std::sync::Arc; +use channels::WindowSizeRef; +use kex::ServerKex; use log::debug; -use russh_keys::encoding::{Encoding, Reader}; +use negotiation::parse_kex_algo_list; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::oneshot; use super::*; -use crate::channels::{Channel, ChannelMsg}; -use crate::kex::EXTENSION_SUPPORT_AS_CLIENT; -use crate::msg; +use crate::channels::{Channel, ChannelMsg, ChannelReadHalf, ChannelRef, ChannelWriteHalf}; +use crate::helpers::NameList; +use crate::kex::{KexCause, SessionKexState, EXTENSION_SUPPORT_AS_CLIENT}; +use crate::{map_err, msg}; /// A connected server session. This type is unique to a client. +#[derive(Debug)] pub struct Session { pub(crate) common: CommonSession>, pub(crate) sender: Handle, @@ -19,40 +25,63 @@ pub struct Session { pub(crate) target_window_size: u32, pub(crate) pending_reads: Vec, pub(crate) pending_len: u32, - pub(crate) channels: HashMap>, + pub(crate) channels: HashMap, + pub(crate) open_global_requests: VecDeque, + pub(crate) kex: SessionKexState, } + #[derive(Debug)] pub enum Msg { + ChannelOpenAgent { + channel_ref: ChannelRef, + }, ChannelOpenSession { - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectTcpIp { host_to_connect: String, port_to_connect: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, + }, + ChannelOpenDirectStreamLocal { + socket_path: String, + channel_ref: ChannelRef, }, ChannelOpenForwardedTcpIp { connected_address: String, connected_port: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, + }, + ChannelOpenForwardedStreamLocal { + server_socket_path: String, + channel_ref: ChannelRef, }, ChannelOpenX11 { originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, TcpIpForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>>, address: String, port: u32, }, CancelTcpIpForward { + /// Provide a channel for the reply result to request a reply from the server + reply_channel: Option>, address: String, port: u32, }, + Disconnect { + reason: crate::Disconnect, + description: String, + language_tag: String, + }, Channel(ChannelId, ChannelMsg), } @@ -62,11 +91,12 @@ impl From<(ChannelId, ChannelMsg)> for Msg { } } -#[derive(Clone)] +#[derive(Clone, Debug)] /// Handle to a session, used to send messages to a client outside of /// the request/response cycle. pub struct Handle { pub(crate) sender: Sender, + pub(crate) channel_buffer_size: usize, } impl Handle { @@ -148,19 +178,63 @@ impl Handle { } /// Notifies the client that it can open TCP/IP forwarding channels for a port. - pub async fn forward_tcpip(&self, address: String, port: u32) -> Result<(), ()> { + pub async fn forward_tcpip(&self, address: String, port: u32) -> Result { + let (reply_send, reply_recv) = oneshot::channel(); self.sender - .send(Msg::TcpIpForward { address, port }) + .send(Msg::TcpIpForward { + reply_channel: Some(reply_send), + address, + port, + }) .await - .map_err(|_| ()) + .map_err(|_| ())?; + + match reply_recv.await { + Ok(Some(port)) => Ok(port), + Ok(None) => Err(()), // crate::Error::RequestDenied + Err(e) => { + error!("Unable to receive TcpIpForward result: {e:?}"); + Err(()) // crate::Error::Disconnect + } + } } /// Notifies the client that it can no longer open TCP/IP forwarding channel for a port. pub async fn cancel_forward_tcpip(&self, address: String, port: u32) -> Result<(), ()> { + let (reply_send, reply_recv) = oneshot::channel(); self.sender - .send(Msg::CancelTcpIpForward { address, port }) + .send(Msg::CancelTcpIpForward { + reply_channel: Some(reply_send), + address, + port, + }) + .await + .map_err(|_| ())?; + match reply_recv.await { + Ok(true) => Ok(()), + Ok(false) => Err(()), // crate::Error::RequestDenied + Err(e) => { + error!("Unable to receive CancelTcpIpForward result: {e:?}"); + Err(()) // crate::Error::Disconnect + } + } + } + + /// Open an agent forwarding channel. This can be used once the client has + /// confirmed that it allows agent forwarding. See + /// [PROTOCOL.agent](https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent). + pub async fn channel_open_agent(&self) -> Result, Error> { + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + + self.sender + .send(Msg::ChannelOpenAgent { channel_ref }) + .await + .map_err(|_| Error::SendError)?; + + self.wait_channel_confirmation(receiver, window_size_ref) .await - .map_err(|_| ()) } /// Request a session channel (the most basic type of @@ -169,12 +243,17 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender - .send(Msg::ChannelOpenSession { sender }) + .send(Msg::ChannelOpenSession { channel_ref }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Open a TCP/IP forwarding channel. This is usually done when a @@ -189,18 +268,42 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectTcpIp { host_to_connect: host_to_connect.into(), port_to_connect, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await + } + + /// Open a direct streamlocal (Unix domain socket) channel on the client. + pub async fn channel_open_direct_streamlocal>( + &self, + socket_path: A, + ) -> Result, Error> { + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + + self.sender + .send(Msg::ChannelOpenDirectStreamLocal { + socket_path: socket_path.into(), + channel_ref, + }) + .await + .map_err(|_| Error::SendError)?; + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_forwarded_tcpip, B: Into>( @@ -210,18 +313,41 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenForwardedTcpIp { connected_address: connected_address.into(), connected_port, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await + } + + pub async fn channel_open_forwarded_streamlocal>( + &self, + server_socket_path: A, + ) -> Result, Error> { + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + + self.sender + .send(Msg::ChannelOpenForwardedStreamLocal { + server_socket_path: server_socket_path.into(), + channel_ref, + }) + .await + .map_err(|_| Error::SendError)?; + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_x11>( @@ -229,21 +355,26 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenX11 { originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, + window_size_ref: WindowSizeRef, ) -> Result, Error> { loop { match receiver.recv().await { @@ -252,12 +383,16 @@ impl Handle { max_packet_size, window_size, }) => { + window_size_ref.update(window_size).await; + return Ok(Channel { - id, - sender: self.sender.clone(), - receiver, - max_packet_size, - window_size, + write_half: ChannelWriteHalf { + id, + sender: self.sender.clone(), + max_packet_size, + window_size: window_size_ref, + }, + read_half: ChannelReadHalf { receiver }, }); } Some(ChannelMsg::OpenFailure(reason)) => { @@ -295,14 +430,43 @@ impl Handle { .await .map_err(|_| ()) } + + /// Allows a server to disconnect a client session + pub async fn disconnect( + &self, + reason: Disconnect, + description: String, + language_tag: String, + ) -> Result<(), Error> { + self.sender + .send(Msg::Disconnect { + reason, + description, + language_tag, + }) + .await + .map_err(|_| Error::SendError) + } } impl Session { - pub(crate) fn is_rekeying(&self) -> bool { - if let Some(ref enc) = self.common.encrypted { - enc.rekey.is_some() + fn maybe_decompress(&mut self, buffer: &SSHBuffer) -> Result { + if let Some(ref mut enc) = self.common.encrypted { + let mut decomp = CryptoVec::new(); + Ok(IncomingSshPacket { + #[allow(clippy::indexing_slicing)] // length checked + buffer: enc.decompress.decompress( + &buffer.buffer[5..], + &mut decomp, + )?.into(), + seqn: buffer.seqn, + }) } else { - true + Ok(IncomingSshPacket { + #[allow(clippy::indexing_slicing)] // length checked + buffer: buffer.buffer[5..].into(), + seqn: buffer.seqn, + }) } } @@ -316,30 +480,35 @@ impl Session { R: AsyncRead + AsyncWrite + Unpin + Send + 'static, { self.flush()?; - stream - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; - self.common.write_buffer.buffer.clear(); + + map_err!(self.common.packet_writer.flush_into(&mut stream).await)?; let (stream_read, mut stream_write) = stream.split(); let buffer = SSHBuffer::new(); // Allow handing out references to the cipher let mut opening_cipher = Box::new(clear::Key) as Box; - std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local); + + let keepalive_timer = + future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); + pin!(keepalive_timer); + + let inactivity_timer = + future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); + pin!(inactivity_timer); let reading = start_reading(stream_read, buffer, opening_cipher); pin!(reading); let mut is_reading = None; - let mut decomp = CryptoVec::new(); - let delay = self.common.config.inactivity_timeout; #[allow(clippy::panic)] // false positive in macro while !self.common.disconnected { + self.common.received_data = false; + let mut sent_keepalive = false; tokio::select! { r = &mut reading => { - let (stream_read, buffer, mut opening_cipher) = match r { + let (stream_read, mut buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), Err(e) => return Err(e.into()) }; @@ -347,100 +516,113 @@ impl Session { is_reading = Some((stream_read, buffer, opening_cipher)); break } - #[allow(clippy::indexing_slicing)] // length checked - let buf = if let Some(ref mut enc) = self.common.encrypted { - let d = enc.decompress.decompress( - &buffer.buffer[5..], - &mut decomp, - ); - if let Ok(buf) = d { - buf - } else { - debug!("err = {:?}", d); - is_reading = Some((stream_read, buffer, opening_cipher)); - break - } - } else { - &buffer.buffer[5..] - }; - if !buf.is_empty() { - #[allow(clippy::indexing_slicing)] // length checked - if buf[0] == crate::msg::DISCONNECT { + + let mut pkt = self.maybe_decompress(&buffer)?; + + match pkt.buffer.first() { + None => (), + Some(&crate::msg::DISCONNECT) => { debug!("break"); is_reading = Some((stream_read, buffer, opening_cipher)); break; - } else if buf[0] > 4 { - std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + } + Some(_) => { + self.common.received_data = true; // TODO it'd be cleaner to just pass cipher to reply() - match reply(self, handler, buf).await { - Ok((h, s)) => { - handler = h; - self = s; - }, + std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local); + + match reply(&mut self, &mut handler, &mut pkt).await { + Ok(_) => {}, Err(e) => return Err(e), } - std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + buffer.seqn = pkt.seqn; // TODO reply changes seqn internall, find cleaner way + + std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local); } } reading.set(start_reading(stream_read, buffer, opening_cipher)); } - _ = timeout(delay) => { + () = &mut keepalive_timer => { + if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { + debug!("Timeout, client not responding to keepalives"); + return Err(crate::Error::KeepaliveTimeout.into()); + } + self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); + sent_keepalive = true; + self.keepalive_request()?; + } + () = &mut inactivity_timer => { debug!("timeout"); - break - }, - msg = self.receiver.recv(), if !self.is_rekeying() => { + return Err(crate::Error::InactivityTimeout.into()); + } + msg = self.receiver.recv(), if !self.kex.active() => { match msg { Some(Msg::Channel(id, ChannelMsg::Data { data })) => { - self.data(id, data); + self.data(id, data)?; } Some(Msg::Channel(id, ChannelMsg::ExtendedData { ext, data })) => { - self.extended_data(id, ext, data); + self.extended_data(id, ext, data)?; } Some(Msg::Channel(id, ChannelMsg::Eof)) => { - self.eof(id); + self.eof(id)?; } Some(Msg::Channel(id, ChannelMsg::Close)) => { - self.close(id); + self.close(id)?; } Some(Msg::Channel(id, ChannelMsg::Success)) => { - self.channel_success(id); + self.channel_success(id)?; } Some(Msg::Channel(id, ChannelMsg::Failure)) => { - self.channel_failure(id); + self.channel_failure(id)?; } Some(Msg::Channel(id, ChannelMsg::XonXoff { client_can_do })) => { - self.xon_xoff_request(id, client_can_do); + self.xon_xoff_request(id, client_can_do)?; } Some(Msg::Channel(id, ChannelMsg::ExitStatus { exit_status })) => { - self.exit_status_request(id, exit_status); + self.exit_status_request(id, exit_status)?; } Some(Msg::Channel(id, ChannelMsg::ExitSignal { signal_name, core_dumped, error_message, lang_tag })) => { - self.exit_signal_request(id, signal_name, core_dumped, &error_message, &lang_tag); + self.exit_signal_request(id, signal_name, core_dumped, &error_message, &lang_tag)?; } Some(Msg::Channel(id, ChannelMsg::WindowAdjusted { new_size })) => { debug!("window adjusted to {:?} for channel {:?}", new_size, id); } - Some(Msg::ChannelOpenSession { sender }) => { + Some(Msg::ChannelOpenAgent { channel_ref }) => { + let id = self.channel_open_agent()?; + self.channels.insert(id, channel_ref); + } + Some(Msg::ChannelOpenSession { channel_ref }) => { let id = self.channel_open_session()?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, channel_ref }) => { let id = self.channel_open_direct_tcpip(&host_to_connect, port_to_connect, &originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenForwardedTcpIp { connected_address, connected_port, originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenDirectStreamLocal { socket_path, channel_ref }) => { + let id = self.channel_open_direct_streamlocal(&socket_path)?; + self.channels.insert(id, channel_ref); + } + Some(Msg::ChannelOpenForwardedTcpIp { connected_address, connected_port, originator_address, originator_port, channel_ref }) => { let id = self.channel_open_forwarded_tcpip(&connected_address, connected_port, &originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); + } + Some(Msg::ChannelOpenForwardedStreamLocal { server_socket_path, channel_ref }) => { + let id = self.channel_open_forwarded_streamlocal(&server_socket_path)?; + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenX11 { originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenX11 { originator_address, originator_port, channel_ref }) => { let id = self.channel_open_x11(&originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::TcpIpForward { address, port }) => { - self.tcpip_forward(&address, port); + Some(Msg::TcpIpForward { address, port, reply_channel }) => { + self.tcpip_forward(&address, port, reply_channel)?; } - Some(Msg::CancelTcpIpForward { address, port }) => { - self.cancel_tcpip_forward(&address, port); + Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => { + self.cancel_tcpip_forward(&address, port, reply_channel)?; + } + Some(Msg::Disconnect {reason, description, language_tag}) => { + self.common.disconnect(reason, &description, &language_tag)?; } Some(_) => { // should be unreachable, since the receiver only gets @@ -454,23 +636,54 @@ impl Session { } } self.flush()?; - stream_write - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; - self.common.write_buffer.buffer.clear(); + + map_err!( + self.common + .packet_writer + .flush_into(&mut stream_write) + .await + )?; + + if self.common.received_data { + // Reset the number of failed keepalive attempts. We don't + // bother detecting keepalive response messages specifically + // (OpenSSH_9.6p1 responds with REQUEST_FAILURE aka 82). Instead + // we assume that the client is still alive if we receive any + // data from it. + self.common.alive_timeouts = 0; + } + if self.common.received_data || sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + keepalive_timer.as_mut().as_pin_mut(), + self.common.config.keepalive_interval, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } + if !sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + inactivity_timer.as_mut().as_pin_mut(), + self.common.config.inactivity_timeout, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } } debug!("disconnected"); // Shutdown - stream_write.shutdown().await.map_err(crate::Error::from)?; + map_err!(stream_write.shutdown().await)?; loop { if let Some((stream_read, buffer, opening_cipher)) = is_reading.take() { reading.set(start_reading(stream_read, buffer, opening_cipher)); } - let (n, r, b, opening_cipher) = (&mut reading).await?; - is_reading = Some((r, b, opening_cipher)); - if n == 0 { - break; + match (&mut reading).await { + Ok((0, _, _, _)) => break, + Ok((_, r, b, opening_cipher)) => { + is_reading = Some((r, b, opening_cipher)); + } + // at this stage of session shutdown, EOF is not unexpected + Err(Error::IO(ref e)) if e.kind() == ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), } } @@ -516,30 +729,23 @@ impl Session { if let Some(ref mut enc) = self.common.encrypted { if enc.flush( &self.common.config.as_ref().limits, - &mut *self.common.cipher.local_to_remote, - &mut self.common.write_buffer, - )? && enc.rekey.is_none() + &mut self.common.packet_writer, + )? && self.kex == SessionKexState::Idle { debug!("starting rekeying"); - if let Some(exchange) = enc.exchange.take() { - let mut kexinit = KexInit::initiate_rekey(exchange, &enc.session_id); - kexinit.server_write( - self.common.config.as_ref(), - &mut *self.common.cipher.local_to_remote, - &mut self.common.write_buffer, - )?; - enc.rekey = Some(Kex::Init(kexinit)) + if enc.exchange.take().is_some() { + self.begin_rekey()?; } } } Ok(()) } - pub fn flush_pending(&mut self, channel: ChannelId) -> usize { + pub fn flush_pending(&mut self, channel: ChannelId) -> Result { if let Some(ref mut enc) = self.common.encrypted { enc.flush_pending(channel) } else { - 0 + Ok(0) } } @@ -565,8 +771,40 @@ impl Session { } /// Sends a disconnect message. - pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) { - self.common.disconnect(reason, description, language_tag); + pub fn disconnect( + &mut self, + reason: Disconnect, + description: &str, + language_tag: &str, + ) -> Result<(), Error> { + self.common.disconnect(reason, description, language_tag) + } + + /// Sends a debug message to the client. + /// + /// Debug messages are intended for debugging purposes and may be + /// optionally displayed by the client, depending on the + /// `always_display` flag and client configuration. + /// + /// # Parameters + /// + /// - `always_display`: If `true`, the client is encouraged to + /// display the message regardless of user preferences. + /// - `message`: The debug message to be sent. + /// - `language_tag`: The language tag of the message. + /// + /// # Notes + /// + /// This message is informational and does not affect the SSH session + /// state. Most clients (e.g., OpenSSH) will only display the message + /// if verbose mode is enabled. + pub fn debug( + &mut self, + always_display: bool, + message: &str, + language_tag: &str, + ) -> Result<(), Error> { + self.common.debug(always_display, message, language_tag) } /// Send a "success" reply to a /global/ request (requests without @@ -593,7 +831,7 @@ impl Session { /// Send a "success" reply to a channel request. Always call this /// function if the request was successful (it checks whether the /// client expects an answer). - pub fn channel_success(&mut self, channel: ChannelId) { + pub fn channel_success(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel) { assert!(channel.confirmed); @@ -601,16 +839,17 @@ impl Session { channel.wants_reply = false; debug!("channel_success {:?}", channel); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_SUCCESS); - enc.write.push_u32_be(channel.recipient_channel); + msg::CHANNEL_SUCCESS.encode(&mut enc.write)?; + channel.recipient_channel.encode(&mut enc.write)?; }) } } } + Ok(()) } /// Send a "failure" reply to a global request. - pub fn channel_failure(&mut self, channel: ChannelId) { + pub fn channel_failure(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel) { assert!(channel.confirmed); @@ -618,11 +857,12 @@ impl Session { channel.wants_reply = false; push_packet!(enc.write, { enc.write.push(msg::CHANNEL_FAILURE); - enc.write.push_u32_be(channel.recipient_channel); + channel.recipient_channel.encode(&mut enc.write)?; }) } } } + Ok(()) } /// Send a "failure" reply to a request to open a channel open. @@ -632,26 +872,27 @@ impl Session { reason: ChannelOpenFailure, description: &str, language: &str, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { push_packet!(enc.write, { enc.write.push(msg::CHANNEL_OPEN_FAILURE); - enc.write.push_u32_be(channel.0); - enc.write.push_u32_be(reason as u32); - enc.write.extend_ssh_string(description.as_bytes()); - enc.write.extend_ssh_string(language.as_bytes()); + channel.encode(&mut enc.write)?; + (reason as u32).encode(&mut enc.write)?; + description.encode(&mut enc.write)?; + language.encode(&mut enc.write)?; }) } + Ok(()) } /// Close a channel. - pub fn close(&mut self, channel: ChannelId) { - self.common.byte(channel, msg::CHANNEL_CLOSE); + pub fn close(&mut self, channel: ChannelId) -> Result<(), Error> { + self.common.byte(channel, msg::CHANNEL_CLOSE) } /// Send EOF to a channel - pub fn eof(&mut self, channel: ChannelId) { - self.common.byte(channel, msg::CHANNEL_EOF); + pub fn eof(&mut self, channel: ChannelId) -> Result<(), Error> { + self.common.byte(channel, msg::CHANNEL_EOF) } /// Send data to a channel. On session channels, `extended` can be @@ -660,9 +901,9 @@ impl Session { /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned. - pub fn data(&mut self, channel: ChannelId, data: CryptoVec) { + pub fn data(&mut self, channel: ChannelId, data: CryptoVec) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { - enc.data(channel, data) + enc.data(channel, data, self.kex.active()) } else { unreachable!() } @@ -674,9 +915,14 @@ impl Session { /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned. - pub fn extended_data(&mut self, channel: ChannelId, extended: u32, data: CryptoVec) { + pub fn extended_data( + &mut self, + channel: ChannelId, + extended: u32, + data: CryptoVec, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { - enc.extended_data(channel, extended, data) + enc.extended_data(channel, extended, data, self.kex.active()) } else { unreachable!() } @@ -685,37 +931,62 @@ impl Session { /// Inform the client of whether they may perform /// control-S/control-Q flow control. See /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.8). - pub fn xon_xoff_request(&mut self, channel: ChannelId, client_can_do: bool) { + pub fn xon_xoff_request( + &mut self, + channel: ChannelId, + client_can_do: bool, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"xon-xoff"); - enc.write.push(0); - enc.write.push(client_can_do as u8); + channel.recipient_channel.encode(&mut enc.write)?; + "xon-xoff".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + (client_can_do as u8).encode(&mut enc.write)?; }) } } + Ok(()) + } + + /// Ping the client to verify there is still connectivity. + pub fn keepalive_request(&mut self) -> Result<(), Error> { + let want_reply = u8::from(true); + if let Some(ref mut enc) = self.common.encrypted { + self.open_global_requests + .push_back(GlobalRequestResponse::Keepalive); + push_packet!(enc.write, { + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "keepalive@openssh.com".encode(&mut enc.write)?; + want_reply.encode(&mut enc.write)?; + }) + } + Ok(()) } /// Send the exit status of a program. - pub fn exit_status_request(&mut self, channel: ChannelId, exit_status: u32) { + pub fn exit_status_request( + &mut self, + channel: ChannelId, + exit_status: u32, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"exit-status"); - enc.write.push(0); - enc.write.push_u32_be(exit_status) + channel.recipient_channel.encode(&mut enc.write)?; + "exit-status".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + exit_status.encode(&mut enc.write)?; }) } } + Ok(()) } /// If the program was killed by a signal, send the details about the signal to the client. @@ -726,31 +997,32 @@ impl Session { core_dumped: bool, error_message: &str, language_tag: &str, - ) { + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"exit-signal"); - enc.write.push(0); - enc.write.extend_ssh_string(signal.name().as_bytes()); - enc.write.push(core_dumped as u8); - enc.write.extend_ssh_string(error_message.as_bytes()); - enc.write.extend_ssh_string(language_tag.as_bytes()); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + + channel.recipient_channel.encode(&mut enc.write)?; + "exit-signal".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + signal.name().encode(&mut enc.write)?; + (core_dumped as u8).encode(&mut enc.write)?; + error_message.encode(&mut enc.write)?; + language_tag.encode(&mut enc.write)?; }) } } + Ok(()) } /// Opens a new session channel on the client. pub fn channel_open_session(&mut self) -> Result { - self.channel_open_generic(b"session", |_| ()) + self.channel_open_generic(b"session", |_| Ok(())) } - /// Opens a direct TCP/IP channel on the client. + /// Opens a direct-tcpip channel on the client (non-standard). pub fn channel_open_direct_tcpip( &mut self, host_to_connect: &str, @@ -759,10 +1031,24 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"direct-tcpip", |write| { - write.extend_ssh_string(host_to_connect.as_bytes()); - write.push_u32_be(port_to_connect); // sender channel id. - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + host_to_connect.encode(write)?; + port_to_connect.encode(write)?; // sender channel id. + originator_address.encode(write)?; + originator_port.encode(write)?; // sender channel id. + Ok(()) + }) + } + + /// Opens a direct-streamlocal channel on the client (non-standard). + pub fn channel_open_direct_streamlocal( + &mut self, + socket_path: &str, + ) -> Result { + self.channel_open_generic(b"direct-streamlocal@openssh.com", |write| { + socket_path.encode(write)?; + "".encode(write)?; // reserved + 0u32.encode(write)?; // reserved + Ok(()) }) } @@ -779,10 +1065,22 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"forwarded-tcpip", |write| { - write.extend_ssh_string(connected_address.as_bytes()); - write.push_u32_be(connected_port); // sender channel id. - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + connected_address.encode(write)?; + connected_port.encode(write)?; // sender channel id. + originator_address.encode(write)?; + originator_port.encode(write)?; // sender channel id. + Ok(()) + }) + } + + pub fn channel_open_forwarded_streamlocal( + &mut self, + socket_path: &str, + ) -> Result { + self.channel_open_generic(b"forwarded-streamlocal@openssh.com", |write| { + socket_path.encode(write)?; + "".encode(write)?; + Ok(()) }) } @@ -795,19 +1093,20 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"x11", |write| { - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); + originator_address.encode(write)?; + originator_port.encode(write)?; + Ok(()) }) } /// Opens a new agent channel on the client. pub fn channel_open_agent(&mut self) -> Result { - self.channel_open_generic(b"auth-agent@openssh.com", |_| ()) + self.channel_open_generic(b"auth-agent@openssh.com", |_| Ok(())) } fn channel_open_generic(&mut self, kind: &[u8], write_suffix: F) -> Result where - F: FnOnce(&mut CryptoVec), + F: FnOnce(&mut CryptoVec) -> Result<(), Error>, { let result = if let Some(ref mut enc) = self.common.encrypted { if !matches!( @@ -823,20 +1122,26 @@ impl Session { ); push_packet!(enc.write, { enc.write.push(msg::CHANNEL_OPEN); - enc.write.extend_ssh_string(kind); + kind.encode(&mut enc.write)?; // sender channel id. - enc.write.push_u32_be(sender_channel.0); + sender_channel.encode(&mut enc.write)?; // window. - enc.write - .push_u32_be(self.common.config.as_ref().window_size); + self.common + .config + .as_ref() + .window_size + .encode(&mut enc.write)?; // max packet size. - enc.write - .push_u32_be(self.common.config.as_ref().maximum_packet_size); + self.common + .config + .as_ref() + .maximum_packet_size + .encode(&mut enc.write)?; - write_suffix(&mut enc.write); + write_suffix(&mut enc.write)?; }); sender_channel } else { @@ -848,59 +1153,129 @@ impl Session { /// Requests that the client forward connections to the given host and port. /// See [RFC4254](https://tools.ietf.org/html/rfc4254#section-7). The client /// will open forwarded_tcpip channels for each connection. - pub fn tcpip_forward(&mut self, address: &str, port: u32) { + pub fn tcpip_forward( + &mut self, + address: &str, + port: u32, + reply_channel: Option>>, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::TcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"tcpip-forward"); - enc.write.push(0); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + "tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } /// Cancels a previously tcpip_forward request. - pub fn cancel_tcpip_forward(&mut self, address: &str, port: u32) { + pub fn cancel_tcpip_forward( + &mut self, + address: &str, + port: u32, + reply_channel: Option>, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { + let want_reply = reply_channel.is_some(); + if let Some(reply_channel) = reply_channel { + self.open_global_requests.push_back( + crate::session::GlobalRequestResponse::CancelTcpIpForward(reply_channel), + ); + } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"cancel-tcpip-forward"); - enc.write.push(0); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "cancel-tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } - pub(crate) fn maybe_send_ext_info(&mut self) { + /// Returns the SSH ID (Protocol Version + Software Version) the client sent when connecting + /// + /// This should contain only ASCII characters for implementations conforming to RFC4253, Section 4.2: + /// + /// > Both the 'protoversion' and 'softwareversion' strings MUST consist of + /// > printable US-ASCII characters, with the exception of whitespace + /// > characters and the minus sign (-). + /// + /// So it usually is fine to convert it to a [`String`] using [`String::from_utf8_lossy`] + pub fn remote_sshid(&self) -> &[u8] { + &self.common.remote_sshid + } + + pub(crate) fn maybe_send_ext_info(&mut self) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { // If client sent a ext-info-c message in the kex list, it supports RFC 8308 extension negotiation. let mut key_extension_client = false; if let Some(e) = &enc.exchange { - let mut r = e.client_kex_init.as_ref().reader(17); - if let Ok(kex_string) = r.read_string() { + let Some(mut r) = &e.client_kex_init.as_ref().get(17..) else { + return Ok(()); + }; + if let Ok(kex_string) = String::decode(&mut r) { use super::negotiation::Select; key_extension_client = super::negotiation::Server::select( &[EXTENSION_SUPPORT_AS_CLIENT], - kex_string, + &parse_kex_algo_list(&kex_string), + AlgorithmKind::Kex, ) - .is_some(); + .is_ok(); } } if !key_extension_client { debug!("RFC 8308 Extension Negotiation not supported by client"); - return; + return Ok(()); } push_packet!(enc.write, { - enc.write.push(msg::EXT_INFO); - enc.write.push_u32_be(1); - enc.write.extend_ssh_string(b"server-sig-algs"); - enc.write - .extend_list(self.common.config.preferred.key.iter()); + msg::EXT_INFO.encode(&mut enc.write)?; + 1u32.encode(&mut enc.write)?; + "server-sig-algs".encode(&mut enc.write)?; + + NameList( + self.common + .config + .preferred + .key + .iter() + .map(|x| x.to_string()) + .collect(), + ) + .encode(&mut enc.write)?; }); } + Ok(()) + } + + pub(crate) fn begin_rekey(&mut self) -> Result<(), Error> { + debug!("beginning re-key"); + let mut kex = ServerKex::new( + self.common.config.clone(), + &self.common.remote_sshid, + &self.common.config.server_id, + match self.common.encrypted { + None => KexCause::Initial, + Some(ref enc) => KexCause::Rekey { + strict: self.common.strict_kex, + session_id: enc.session_id.clone(), + }, + }, + ); + + kex.kexinit(&mut self.common.packet_writer)?; + self.kex = SessionKexState::InProgress(kex); + Ok(()) } } diff --git a/russh/src/session.rs b/russh/src/session.rs index 09afa95a..661306dc 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -15,17 +15,22 @@ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; +use std::mem::replace; use std::num::Wrapping; use byteorder::{BigEndian, ByteOrder}; use log::{debug, trace}; -use russh_cryptovec::CryptoVec; -use russh_keys::encoding::Encoding; - -use crate::cipher::SealingKey; -use crate::kex::KexAlgorithm; -use crate::sshbuffer::SSHBuffer; -use crate::{auth, cipher, mac, msg, negotiation, ChannelId, ChannelParams, Disconnect, Limits}; +use ssh_encoding::Encode; +use tokio::sync::oneshot; + +use crate::cipher::OpeningKey; +use crate::client::GexParams; +use crate::kex::dh::groups::DhGroup; +use crate::kex::{KexAlgorithm, KexAlgorithmImplementor}; +use crate::sshbuffer::PacketWriter; +use crate::{ + auth, cipher, mac, msg, negotiation, ChannelId, ChannelParams, CryptoVec, Disconnect, Limits, +}; #[derive(Debug)] pub(crate) struct Encrypted { @@ -33,36 +38,86 @@ pub(crate) struct Encrypted { // It's always Some, except when we std::mem::replace it temporarily. pub exchange: Option, - pub kex: Box, + pub kex: KexAlgorithm, pub key: usize, pub client_mac: mac::Name, pub server_mac: mac::Name, pub session_id: CryptoVec, - pub rekey: Option, pub channels: HashMap, pub last_channel_id: Wrapping, pub write: CryptoVec, pub write_cursor: usize, - pub last_rekey: std::time::Instant, + pub last_rekey: russh_util::time::Instant, pub server_compression: crate::compression::Compression, pub client_compression: crate::compression::Compression, - pub compress: crate::compression::Compress, pub decompress: crate::compression::Decompress, - pub compress_buffer: CryptoVec, + pub rekey_wanted: bool, + pub received_extensions: Vec, + pub extension_info_awaiters: HashMap>>, } pub(crate) struct CommonSession { pub auth_user: String, + pub remote_sshid: Vec, pub config: Config, pub encrypted: Option, pub auth_method: Option, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub(crate) auth_attempts: usize, - pub write_buffer: SSHBuffer, - pub kex: Option, - pub cipher: cipher::CipherPair, + pub packet_writer: PacketWriter, + pub remote_to_local: Box, pub wants_reply: bool, pub disconnected: bool, pub buffer: CryptoVec, + pub strict_kex: bool, + pub alive_timeouts: usize, + pub received_data: bool, +} + +impl Debug for CommonSession { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommonSession") + .field("auth_user", &self.auth_user) + .field("remote_sshid", &self.remote_sshid) + .field("encrypted", &self.encrypted) + .field("auth_method", &self.auth_method) + .field("auth_attempts", &self.auth_attempts) + .field("packet_writer", &self.packet_writer) + .field("wants_reply", &self.wants_reply) + .field("disconnected", &self.disconnected) + .field("buffer", &self.buffer) + .field("strict_kex", &self.strict_kex) + .field("alive_timeouts", &self.alive_timeouts) + .field("received_data", &self.received_data) + .finish() + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) enum ChannelFlushResult { + Incomplete { + wrote: usize, + }, + Complete { + wrote: usize, + pending_eof: bool, + pending_close: bool, + }, +} +impl ChannelFlushResult { + pub(crate) fn wrote(&self) -> usize { + match self { + ChannelFlushResult::Incomplete { wrote } => *wrote, + ChannelFlushResult::Complete { wrote, .. } => *wrote, + } + } + pub(crate) fn complete(wrote: usize, channel: &ChannelParams) -> Self { + ChannelFlushResult::Complete { + wrote, + pending_eof: channel.pending_eof, + pending_close: channel.pending_close, + } + } } impl CommonSession { @@ -73,7 +128,10 @@ impl CommonSession { enc.key = newkeys.key; enc.client_mac = newkeys.names.client_mac; enc.server_mac = newkeys.names.server_mac; - self.cipher = newkeys.cipher; + self.remote_to_local = newkeys.cipher.remote_to_local; + self.packet_writer + .set_cipher(newkeys.cipher.local_to_remote); + self.strict_kex = self.strict_kex || newkeys.names.strict_kex; } } @@ -86,73 +144,116 @@ impl CommonSession { server_mac: newkeys.names.server_mac, session_id: newkeys.session_id, state, - rekey: None, channels: HashMap::new(), last_channel_id: Wrapping(1), write: CryptoVec::new(), write_cursor: 0, - last_rekey: std::time::Instant::now(), + last_rekey: russh_util::time::Instant::now(), server_compression: newkeys.names.server_compression, client_compression: newkeys.names.client_compression, - compress: crate::compression::Compress::None, - compress_buffer: CryptoVec::new(), decompress: crate::compression::Decompress::None, + rekey_wanted: false, + received_extensions: Vec::new(), + extension_info_awaiters: HashMap::new(), }); - self.cipher = newkeys.cipher; + self.remote_to_local = newkeys.cipher.remote_to_local; + self.packet_writer + .set_cipher(newkeys.cipher.local_to_remote); + self.strict_kex = newkeys.names.strict_kex; } /// Send a disconnect message. - pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) { + pub fn disconnect( + &mut self, + reason: Disconnect, + description: &str, + language_tag: &str, + ) -> Result<(), crate::Error> { let disconnect = |buf: &mut CryptoVec| { push_packet!(buf, { - buf.push(msg::DISCONNECT); - buf.push_u32_be(reason as u32); - buf.extend_ssh_string(description.as_bytes()); - buf.extend_ssh_string(language_tag.as_bytes()); + msg::DISCONNECT.encode(buf)?; + (reason as u32).encode(buf)?; + description.encode(buf)?; + language_tag.encode(buf)?; }); + Ok(()) }; if !self.disconnected { self.disconnected = true; - if let Some(ref mut enc) = self.encrypted { + return if let Some(ref mut enc) = self.encrypted { disconnect(&mut enc.write) } else { - disconnect(&mut self.write_buffer.buffer) - } + disconnect(&mut self.packet_writer.buffer().buffer) + }; } + Ok(()) + } + + /// Send a debug message. + pub fn debug( + &mut self, + always_display: bool, + message: &str, + language_tag: &str, + ) -> Result<(), crate::Error> { + let debug = |buf: &mut CryptoVec| { + push_packet!(buf, { + msg::DEBUG.encode(buf)?; + (always_display as u8).encode(buf)?; + message.encode(buf)?; + language_tag.encode(buf)?; + }); + Ok(()) + }; + return if let Some(ref mut enc) = self.encrypted { + debug(&mut enc.write) + } else { + debug(&mut self.packet_writer.buffer().buffer) + }; } /// Send a single byte message onto the channel. - pub fn byte(&mut self, channel: ChannelId, msg: u8) { + #[cfg(not(target_arch = "wasm32"))] + pub fn byte(&mut self, channel: ChannelId, msg: u8) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.encrypted { - enc.byte(channel, msg) + enc.byte(channel, msg)? } + Ok(()) + } + + pub(crate) fn reset_seqn(&mut self) { + self.packet_writer.reset_seqn(); } } impl Encrypted { - pub fn byte(&mut self, channel: ChannelId, msg: u8) { + pub fn byte(&mut self, channel: ChannelId, msg: u8) -> Result<(), crate::Error> { if let Some(channel) = self.channels.get(&channel) { push_packet!(self.write, { self.write.push(msg); - self.write.push_u32_be(channel.recipient_channel); + channel.recipient_channel.encode(&mut self.write)?; }); } + Ok(()) } - /* - pub fn authenticated(&mut self) { - self.server_compression.init_compress(&mut self.compress); - self.state = EncryptedState::Authenticated; - } - */ - - pub fn eof(&mut self, channel: ChannelId) { - self.byte(channel, msg::CHANNEL_EOF); + pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> { + if let Some(channel) = self.has_pending_data_mut(channel) { + channel.pending_eof = true; + } else { + self.byte(channel, msg::CHANNEL_EOF)?; + } + Ok(()) } - pub fn close(&mut self, channel: ChannelId) { - self.byte(channel, msg::CHANNEL_CLOSE); - self.channels.remove(&channel); + pub fn close(&mut self, channel: ChannelId) -> Result<(), crate::Error> { + if let Some(channel) = self.has_pending_data_mut(channel) { + channel.pending_close = true; + } else { + self.byte(channel, msg::CHANNEL_CLOSE)?; + self.channels.remove(&channel); + } + Ok(()) } pub fn sender_window_size(&self, channel: ChannelId) -> usize { @@ -163,7 +264,12 @@ impl Encrypted { } } - pub fn adjust_window_size(&mut self, channel: ChannelId, data: &[u8], target: u32) -> bool { + pub fn adjust_window_size( + &mut self, + channel: ChannelId, + data: &[u8], + target: u32, + ) -> Result { if let Some(channel) = self.channels.get_mut(&channel) { trace!( "adjust_window_size, channel = {}, size = {},", @@ -182,41 +288,81 @@ impl Encrypted { ); push_packet!(self.write, { self.write.push(msg::CHANNEL_WINDOW_ADJUST); - self.write.push_u32_be(channel.recipient_channel); - self.write.push_u32_be(target - channel.sender_window_size); + channel.recipient_channel.encode(&mut self.write)?; + (target - channel.sender_window_size).encode(&mut self.write)?; }); channel.sender_window_size = target; - return true; + return Ok(true); } } - false + Ok(false) } - pub fn flush_pending(&mut self, channel: ChannelId) -> usize { + fn flush_channel( + write: &mut CryptoVec, + channel: &mut ChannelParams, + ) -> Result { let mut pending_size = 0; - if let Some(channel) = self.channels.get_mut(&channel) { - while let Some((buf, a, from)) = channel.pending_data.pop_front() { - let size = Self::data_noqueue(&mut self.write, channel, &buf, from); - pending_size += size; - if from + size < buf.len() { - channel.pending_data.push_front((buf, a, from + size)); - break; - } + while let Some((buf, a, from)) = channel.pending_data.pop_front() { + let size = Self::data_noqueue(write, channel, &buf, a, from)?; + pending_size += size; + if from + size < buf.len() { + channel.pending_data.push_front((buf, a, from + size)); + return Ok(ChannelFlushResult::Incomplete { + wrote: pending_size, + }); } } - pending_size + Ok(ChannelFlushResult::complete(pending_size, channel)) } - pub fn flush_all_pending(&mut self) { - for (_, channel) in self.channels.iter_mut() { - while let Some((buf, a, from)) = channel.pending_data.pop_front() { - let size = Self::data_noqueue(&mut self.write, channel, &buf, from); - if from + size < buf.len() { - channel.pending_data.push_front((buf, a, from + size)); - break; - } + fn handle_flushed_channel( + &mut self, + channel: ChannelId, + flush_result: ChannelFlushResult, + ) -> Result<(), crate::Error> { + if let ChannelFlushResult::Complete { + wrote: _, + pending_eof, + pending_close, + } = flush_result + { + if pending_eof { + self.eof(channel)?; } + if pending_close { + self.close(channel)?; + } + } + Ok(()) + } + + pub fn flush_pending(&mut self, channel: ChannelId) -> Result { + let mut pending_size = 0; + let mut maybe_flush_result = Option::::None; + + if let Some(channel) = self.channels.get_mut(&channel) { + let flush_result = Self::flush_channel(&mut self.write, channel)?; + pending_size += flush_result.wrote(); + maybe_flush_result = Some(flush_result); + } + if let Some(flush_result) = maybe_flush_result { + self.handle_flushed_channel(channel, flush_result)? + } + Ok(pending_size) + } + + pub fn flush_all_pending(&mut self) -> Result<(), crate::Error> { + for channel in self.channels.values_mut() { + Self::flush_channel(&mut self.write, channel)?; } + Ok(()) + } + + fn has_pending_data_mut(&mut self, channel: ChannelId) -> Option<&mut ChannelParams> { + self.channels + .get_mut(&channel) + .filter(|c| !c.pending_data.is_empty()) } pub fn has_pending_data(&self, channel: ChannelId) -> bool { @@ -234,10 +380,11 @@ impl Encrypted { write: &mut CryptoVec, channel: &mut ChannelParams, buf0: &[u8], + a: Option, from: usize, - ) -> usize { + ) -> Result { if from >= buf0.len() { - return 0; + return Ok(0); } let mut buf = if buf0.len() as u32 > from as u32 + channel.recipient_window_size { #[allow(clippy::indexing_slicing)] // length checked @@ -251,12 +398,21 @@ impl Encrypted { while !buf.is_empty() { // Compute the length we're allowed to send. let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize); - push_packet!(write, { - write.push(msg::CHANNEL_DATA); - write.push_u32_be(channel.recipient_channel); - #[allow(clippy::indexing_slicing)] // length checked - write.extend_ssh_string(&buf[..off]); - }); + match a { + None => push_packet!(write, { + write.push(msg::CHANNEL_DATA); + channel.recipient_channel.encode(write)?; + #[allow(clippy::indexing_slicing)] // length checked + buf[..off].encode(write)?; + }), + Some(ext) => push_packet!(write, { + write.push(msg::CHANNEL_EXTENDED_DATA); + channel.recipient_channel.encode(write)?; + ext.encode(write)?; + #[allow(clippy::indexing_slicing)] // length checked + buf[..off].encode(write)?; + }), + } trace!( "buffer: {:?} {:?}", write.len(), @@ -269,70 +425,56 @@ impl Encrypted { } } trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len); - buf_len + Ok(buf_len) } - pub fn data(&mut self, channel: ChannelId, buf0: CryptoVec) { + pub fn data( + &mut self, + channel: ChannelId, + buf0: CryptoVec, + is_rekeying: bool, + ) -> Result<(), crate::Error> { if let Some(channel) = self.channels.get_mut(&channel) { assert!(channel.confirmed); - if !channel.pending_data.is_empty() || self.rekey.is_some() { + if !channel.pending_data.is_empty() && is_rekeying { channel.pending_data.push_back((buf0, None, 0)); - return; + return Ok(()); } - let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, 0); + let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0)?; if buf_len < buf0.len() { channel.pending_data.push_back((buf0, None, buf_len)) } } else { debug!("{:?} not saved for this session", channel); } + Ok(()) } - pub fn extended_data(&mut self, channel: ChannelId, ext: u32, buf0: CryptoVec) { - use std::ops::Deref; + pub fn extended_data( + &mut self, + channel: ChannelId, + ext: u32, + buf0: CryptoVec, + is_rekeying: bool, + ) -> Result<(), crate::Error> { if let Some(channel) = self.channels.get_mut(&channel) { assert!(channel.confirmed); - if !channel.pending_data.is_empty() { + if !channel.pending_data.is_empty() && is_rekeying { channel.pending_data.push_back((buf0, Some(ext), 0)); - return; + return Ok(()); } - let mut buf = if buf0.len() as u32 > channel.recipient_window_size { - #[allow(clippy::indexing_slicing)] // length checked - &buf0[0..channel.recipient_window_size as usize] - } else { - &buf0 - }; - let buf_len = buf.len(); - - while !buf.is_empty() { - // Compute the length we're allowed to send. - let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize); - push_packet!(self.write, { - self.write.push(msg::CHANNEL_EXTENDED_DATA); - self.write.push_u32_be(channel.recipient_channel); - self.write.push_u32_be(ext); - #[allow(clippy::indexing_slicing)] // length checked - self.write.extend_ssh_string(&buf[..off]); - }); - trace!("buffer: {:?}", self.write.deref().len()); - channel.recipient_window_size -= off as u32; - #[allow(clippy::indexing_slicing)] // length checked - { - buf = &buf[off..] - } - } - trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len); + let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0)?; if buf_len < buf0.len() { channel.pending_data.push_back((buf0, Some(ext), buf_len)) } } + Ok(()) } pub fn flush( &mut self, limits: &Limits, - cipher: &mut dyn SealingKey, - write_buffer: &mut SSHBuffer, + writer: &mut PacketWriter, ) -> Result { // If there are pending packets (and we've not started to rekey), flush them. { @@ -342,12 +484,9 @@ impl Encrypted { let len = BigEndian::read_u32(&self.write[self.write_cursor..]) as usize; #[allow(clippy::indexing_slicing)] let to_write = &self.write[(self.write_cursor + 4)..(self.write_cursor + 4 + len)]; - trace!("server_write_encrypted, buf = {:?}", to_write); - #[allow(clippy::indexing_slicing)] - let packet = self - .compress - .compress(to_write, &mut self.compress_buffer)?; - cipher.write(packet, write_buffer); + trace!("session_write_encrypted, buf = {:?}", to_write); + + writer.packet_raw(to_write)?; self.write_cursor += 4 + len } } @@ -361,10 +500,13 @@ impl Encrypted { return Ok(false); } - let now = std::time::Instant::now(); + let now = russh_util::time::Instant::now(); let dur = now.duration_since(self.last_rekey); - Ok(write_buffer.bytes >= limits.rekey_write_limit || dur >= limits.rekey_time_limit) + Ok(replace(&mut self.rekey_wanted, false) + || writer.buffer().bytes >= limits.rekey_write_limit + || dur >= limits.rekey_time_limit) } + pub fn new_channel_id(&mut self) -> ChannelId { self.last_channel_id += Wrapping(1); while self @@ -391,6 +533,8 @@ impl Encrypted { confirmed: false, wants_reply: false, pending_data: std::collections::VecDeque::new(), + pending_eof: false, + pending_close: false, }); return ChannelId(self.last_channel_id.0); } @@ -406,7 +550,7 @@ pub enum EncryptedState { Authenticated, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct Exchange { pub client_id: CryptoVec, pub server_id: CryptoVec, @@ -414,141 +558,40 @@ pub struct Exchange { pub server_kex_init: CryptoVec, pub client_ephemeral: CryptoVec, pub server_ephemeral: CryptoVec, + pub gex: Option<(GexParams, DhGroup)>, } impl Exchange { - pub fn new() -> Self { + pub fn new(client_id: &[u8], server_id: &[u8]) -> Self { Exchange { - client_id: CryptoVec::new(), - server_id: CryptoVec::new(), - client_kex_init: CryptoVec::new(), - server_kex_init: CryptoVec::new(), - client_ephemeral: CryptoVec::new(), - server_ephemeral: CryptoVec::new(), + client_id: client_id.into(), + server_id: server_id.into(), + ..Default::default() } } } -#[derive(Debug)] -pub(crate) enum Kex { - /// Version number sent. `algo` and `sent` tell wether kexinit has - /// been received, and sent, respectively. - Init(KexInit), - - /// Algorithms have been determined, the DH algorithm should run. - Dh(KexDh), - - /// The kex has run. - DhDone(KexDhDone), - - /// The DH is over, we've sent the NEWKEYS packet, and are waiting - /// the NEWKEYS from the other side. - Keys(NewKeys), -} - -#[derive(Debug)] -pub(crate) struct KexInit { - pub algo: Option, - pub exchange: Exchange, - pub session_id: Option, - pub sent: bool, -} - -impl KexInit { - pub fn received_rekey(ex: Exchange, algo: negotiation::Names, session_id: &CryptoVec) -> Self { - let mut kexinit = KexInit { - exchange: ex, - algo: Some(algo), - sent: false, - session_id: Some(session_id.clone()), - }; - kexinit.exchange.client_kex_init.clear(); - kexinit.exchange.server_kex_init.clear(); - kexinit.exchange.client_ephemeral.clear(); - kexinit.exchange.server_ephemeral.clear(); - kexinit - } - - pub fn initiate_rekey(ex: Exchange, session_id: &CryptoVec) -> Self { - let mut kexinit = KexInit { - exchange: ex, - algo: None, - sent: true, - session_id: Some(session_id.clone()), - }; - kexinit.exchange.client_kex_init.clear(); - kexinit.exchange.server_kex_init.clear(); - kexinit.exchange.client_ephemeral.clear(); - kexinit.exchange.server_ephemeral.clear(); - kexinit - } -} - -#[derive(Debug)] -pub(crate) struct KexDh { - pub exchange: Exchange, - pub names: negotiation::Names, - pub key: usize, - pub session_id: Option, -} - -pub(crate) struct KexDhDone { - pub exchange: Exchange, - pub kex: Box, - pub key: usize, - pub session_id: Option, - pub names: negotiation::Names, -} - -impl Debug for KexDhDone { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "KexDhDone") - } -} - -impl KexDhDone { - pub fn compute_keys(self, hash: CryptoVec, is_server: bool) -> Result { - let session_id = if let Some(session_id) = self.session_id { - session_id - } else { - hash.clone() - }; - // Now computing keys. - let c = self.kex.compute_keys( - &session_id, - &hash, - self.names.cipher, - if is_server { - self.names.client_mac - } else { - self.names.server_mac - }, - if is_server { - self.names.server_mac - } else { - self.names.client_mac - }, - is_server, - )?; - Ok(NewKeys { - exchange: self.exchange, - names: self.names, - kex: self.kex, - key: self.key, - cipher: c, - session_id, - sent: false, - }) - } -} - #[derive(Debug)] pub(crate) struct NewKeys { pub exchange: Exchange, pub names: negotiation::Names, - pub kex: Box, + pub kex: KexAlgorithm, pub key: usize, pub cipher: cipher::CipherPair, pub session_id: CryptoVec, - pub sent: bool, +} + +#[derive(Debug)] +pub(crate) enum GlobalRequestResponse { + /// request was for Keepalive, ignore result + Keepalive, + /// request was for NoMoreSessions, disallow additional sessions + NoMoreSessions, + /// request was for TcpIpForward, sends Some(port) for success or None for failure + TcpIpForward(oneshot::Sender>), + /// request was for CancelTcpIpForward, sends true for success or false for failure + CancelTcpIpForward(oneshot::Sender), + /// request was for StreamLocalForward, sends true for success or false for failure + StreamLocalForward(oneshot::Sender), + CancelStreamLocalForward(oneshot::Sender), } diff --git a/russh/src/ssh_read.rs b/russh/src/ssh_read.rs index d74b19dd..c02476fb 100644 --- a/russh/src/ssh_read.rs +++ b/russh/src/ssh_read.rs @@ -1,11 +1,10 @@ use std::pin::Pin; use futures::task::*; -use russh_cryptovec::CryptoVec; +use log::trace; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; -use log::debug; -use crate::Error; +use crate::{CryptoVec, Error}; /// The buffer to read the identification string (first line in the /// protocol). @@ -62,7 +61,7 @@ impl AsyncRead for SshRead { buf: &mut ReadBuf, ) -> Poll> { if let Some(mut id) = self.id.take() { - debug!("id {:?} {:?}", id.total, id.bytes_read); + trace!("id {:?} {:?}", id.total, id.bytes_read); if id.total > id.bytes_read { let total = id.total.min(id.bytes_read + buf.remaining()); #[allow(clippy::indexing_slicing)] // length checked @@ -119,16 +118,16 @@ impl SshRead { let ssh_id = self.id.as_mut().unwrap(); loop { let mut i = 0; - debug!("read_ssh_id: reading"); + trace!("read_ssh_id: reading"); #[allow(clippy::indexing_slicing)] // length checked let n = AsyncReadExt::read(&mut self.r, &mut ssh_id.buf[ssh_id.total..]).await?; - debug!("read {:?}", n); + trace!("read {:?}", n); ssh_id.total += n; #[allow(clippy::indexing_slicing)] // length checked { - debug!("{:?}", std::str::from_utf8(&ssh_id.buf[..ssh_id.total])); + trace!("{:?}", std::str::from_utf8(&ssh_id.buf[..ssh_id.total])); } if n == 0 { return Err(Error::Disconnect); @@ -154,11 +153,15 @@ impl SshRead { if ssh_id.bytes_read > 0 { // If we have a full line, handle it. - if i >= 8 && ssh_id.buf.get(0..8) == Some(b"SSH-2.0-") { - // Either the line starts with "SSH-2.0-" - ssh_id.sshid_len = i; - #[allow(clippy::indexing_slicing)] // length checked - return Ok(&ssh_id.buf[..ssh_id.sshid_len]); + if i >= 8 { + // Check if we have a valid SSH protocol identifier + #[allow(clippy::indexing_slicing)] + if let Ok(s) = std::str::from_utf8(&ssh_id.buf[..i]) { + if s.starts_with("SSH-1.99-") || s.starts_with("SSH-2.0-") { + ssh_id.sshid_len = i; + return Ok(ssh_id.id()); + } + } } // Else, it is a "preliminary" (see // https://tools.ietf.org/html/rfc4253#section-4.2), @@ -166,7 +169,7 @@ impl SshRead { ssh_id.total = 0; ssh_id.bytes_read = 0; } - debug!("bytes_read: {:?}", ssh_id.bytes_read); + trace!("bytes_read: {:?}", ssh_id.bytes_read); } } } diff --git a/russh/src/sshbuffer.rs b/russh/src/sshbuffer.rs index a57b0982..feee758f 100644 --- a/russh/src/sshbuffer.rs +++ b/russh/src/sshbuffer.rs @@ -13,8 +13,13 @@ // limitations under the License. // +use core::fmt; use std::num::Wrapping; +use cipher::SealingKey; +use compression::Compress; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + use super::*; /// The SSH client/server identification string. @@ -30,7 +35,7 @@ impl SshId { pub(crate) fn as_kex_hash_bytes(&self) -> &[u8] { match self { Self::Standard(s) => s.as_bytes(), - Self::Raw(s) => s.trim_end_matches(|c| c == '\n' || c == '\r').as_bytes(), + Self::Raw(s) => s.trim_end_matches(['\n', '\r']).as_bytes(), } } @@ -65,8 +70,8 @@ fn test_ssh_id() { #[derive(Debug, Default)] pub struct SSHBuffer { pub buffer: CryptoVec, - pub len: usize, // next packet length. - pub bytes: usize, + pub len: usize, // next packet length. + pub bytes: usize, // total bytes written since the last rekey // Sequence numbers are on 32 bits and wrap. // https://tools.ietf.org/html/rfc4253#section-6.4 pub seqn: Wrapping, @@ -86,3 +91,82 @@ impl SSHBuffer { id.write(&mut self.buffer); } } + +#[derive(Debug)] +pub(crate) struct IncomingSshPacket { + pub buffer: CryptoVec, + pub seqn: Wrapping, +} + +pub(crate) struct PacketWriter { + cipher: Box, + compress: Compress, + compress_buffer: CryptoVec, + write_buffer: SSHBuffer, +} + +impl Debug for PacketWriter { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PacketWriter").finish() + } +} + +impl PacketWriter { + pub fn clear() -> Self { + Self::new(Box::new(cipher::clear::Key {}), Compress::None) + } + + pub fn new(cipher: Box, compress: Compress) -> Self { + Self { + cipher, + compress, + compress_buffer: CryptoVec::new(), + write_buffer: SSHBuffer::new(), + } + } + + pub fn packet_raw(&mut self, buf: &[u8]) -> Result<(), Error> { + if let Some(message_type) = buf.first() { + debug!("> msg type {message_type:?}, len {}", buf.len()); + let packet = self.compress.compress(buf, &mut self.compress_buffer)?; + self.cipher.write(packet, &mut self.write_buffer); + } + Ok(()) + } + + /// Sends and returns the packet contents + pub fn packet Result<(), Error>>( + &mut self, + f: F, + ) -> Result { + let mut buf = CryptoVec::new(); + f(&mut buf)?; + self.packet_raw(&buf)?; + Ok(buf) + } + + pub fn buffer(&mut self) -> &mut SSHBuffer { + &mut self.write_buffer + } + + pub fn compress(&mut self) -> &mut Compress { + &mut self.compress + } + + pub fn set_cipher(&mut self, cipher: Box) { + self.cipher = cipher; + } + + pub fn reset_seqn(&mut self) { + self.write_buffer.seqn = Wrapping(0); + } + + pub async fn flush_into(&mut self, w: &mut W) -> std::io::Result<()> { + if !self.write_buffer.buffer.is_empty() { + w.write_all(&self.write_buffer.buffer).await?; + w.flush().await?; + self.write_buffer.buffer.clear(); + } + Ok(()) + } +} diff --git a/russh/src/tests.rs b/russh/src/tests.rs new file mode 100644 index 00000000..ecff9059 --- /dev/null +++ b/russh/src/tests.rs @@ -0,0 +1,619 @@ +#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] // Allow unwraps, expects and panics in the test suite + +use futures::Future; + +use super::*; + +mod compress { + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + use keys::PrivateKeyWithHashAlg; + use log::debug; + use rand_core::OsRng; + use ssh_key::PrivateKey; + + use super::server::{Server as _, Session}; + use super::*; + use crate::server::Msg; + + #[tokio::test] + async fn compress_local_test() { + let _ = env_logger::try_init(); + + let client_key = PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(); + let mut config = server::Config::default(); + config.preferred = Preferred::COMPRESSED; + config.inactivity_timeout = None; // Some(std::time::Duration::from_secs(3)); + config.auth_rejection_time = std::time::Duration::from_secs(3); + config + .keys + .push(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + let config = Arc::new(config); + let mut sh = Server { + clients: Arc::new(Mutex::new(HashMap::new())), + id: 0, + }; + + let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + + tokio::spawn(async move { + let (socket, _) = socket.accept().await.unwrap(); + let server = sh.new_client(socket.peer_addr().ok()); + server::run_stream(config, socket, server).await.unwrap(); + }); + + let mut config = client::Config::default(); + config.preferred = Preferred::COMPRESSED; + let config = Arc::new(config); + + let mut session = client::connect(config, addr, Client {}).await.unwrap(); + let authenticated = session + .authenticate_publickey( + std::env::var("USER").unwrap_or("user".to_owned()), + PrivateKeyWithHashAlg::new( + Arc::new(client_key), + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) + .await + .unwrap() + .success(); + assert!(authenticated); + let mut channel = session.channel_open_session().await.unwrap(); + + let data = &b"Hello, world!"[..]; + channel.data(data).await.unwrap(); + let msg = channel.wait().await.unwrap(); + match msg { + ChannelMsg::Data { data: msg_data } => { + assert_eq!(*data, *msg_data) + } + msg => panic!("Unexpected message {:?}", msg), + } + } + + #[derive(Clone)] + struct Server { + clients: Arc>>, + id: usize, + } + + impl server::Server for Server { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + let s = self.clone(); + self.id += 1; + s + } + } + + impl server::Handler for Server { + type Error = super::Error; + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + { + let mut clients = self.clients.lock().unwrap(); + clients.insert((self.id, channel.id()), session.handle()); + } + Ok(true) + } + async fn auth_publickey( + &mut self, + _: &str, + _: &crate::keys::ssh_key::PublicKey, + ) -> Result { + debug!("auth_publickey"); + Ok(server::Auth::Accept) + } + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + debug!("server data = {:?}", std::str::from_utf8(data)); + session.data(channel, CryptoVec::from_slice(data))?; + Ok(()) + } + } + + struct Client {} + + impl client::Handler for Client { + type Error = super::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &crate::keys::ssh_key::PublicKey, + ) -> Result { + // println!("check_server_key: {:?}", server_public_key); + Ok(true) + } + } +} + +mod channels { + use keys::PrivateKeyWithHashAlg; + use rand_core::OsRng; + use server::Session; + use ssh_key::PrivateKey; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use super::*; + use crate::CryptoVec; + + async fn test_session( + client_handler: CH, + server_handler: SH, + run_client: RC, + run_server: RS, + ) where + RC: FnOnce(crate::client::Handle) -> F1 + Send + Sync + 'static, + RS: FnOnce(crate::server::Handle) -> F2 + Send + Sync + 'static, + F1: Future> + Send + Sync + 'static, + F2: Future + Send + Sync + 'static, + CH: crate::client::Handler + Send + Sync + 'static, + SH: crate::server::Handler + Send + Sync + 'static, + { + use std::sync::Arc; + + use crate::*; + + let _ = env_logger::try_init(); + + let client_key = PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap(); + let mut config = server::Config::default(); + config.inactivity_timeout = None; + config.auth_rejection_time = std::time::Duration::from_secs(3); + config + .keys + .push(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + let config = Arc::new(config); + let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + + let server_join = tokio::spawn(async move { + let (socket, _) = socket.accept().await.unwrap(); + + server::run_stream(config, socket, server_handler) + .await + .map_err(|_| ()) + .unwrap() + }); + + let client_join = tokio::spawn(async move { + let config = Arc::new(client::Config::default()); + let mut session = client::connect(config, addr, client_handler) + .await + .map_err(|_| ()) + .unwrap(); + let authenticated = session + .authenticate_publickey( + std::env::var("USER").unwrap_or("user".to_owned()), + PrivateKeyWithHashAlg::new(Arc::new(client_key), None), + ) + .await + .unwrap(); + assert!(authenticated.success()); + session + }); + + let (server_session, client_session) = tokio::join!(server_join, client_join); + let client_handle = tokio::spawn(run_client(client_session.unwrap())); + let server_handle = tokio::spawn(run_server(server_session.unwrap().handle())); + + let (server_session, client_session) = tokio::join!(server_handle, client_handle); + assert!(server_session.is_ok()); + assert!(client_session.is_ok()); + drop(client_session); + drop(server_session); + } + + #[tokio::test] + async fn test_server_channels() { + #[derive(Debug)] + struct Client {} + + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(true) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut client::Session, + ) -> Result<(), Self::Error> { + assert_eq!(data, &b"hello world!"[..]); + session.data(channel, CryptoVec::from_slice(&b"hey there!"[..]))?; + Ok(()) + } + } + + struct ServerHandle { + did_auth: Option>, + } + + impl ServerHandle { + fn get_auth_waiter(&mut self) -> tokio::sync::oneshot::Receiver<()> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.did_auth = Some(tx); + rx + } + } + + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + async fn auth_succeeded(&mut self, _session: &mut Session) -> Result<(), Self::Error> { + if let Some(a) = self.did_auth.take() { + a.send(()).unwrap(); + } + Ok(()) + } + } + + let mut sh = ServerHandle { did_auth: None }; + let a = sh.get_auth_waiter(); + test_session( + Client {}, + sh, + |c| async move { c }, + |s| async move { + a.await.unwrap(); + let mut ch = s.channel_open_session().await.unwrap(); + ch.data(&b"hello world!"[..]).await.unwrap(); + + let msg = ch.wait().await.unwrap(); + if let ChannelMsg::Data { data } = msg { + assert_eq!(data.as_ref(), &b"hey there!"[..]); + } else { + panic!("Unexpected message {:?}", msg); + } + s + }, + ) + .await; + } + + #[tokio::test] + async fn test_channel_streams() { + #[derive(Debug)] + struct Client {} + + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(true) + } + } + + struct ServerHandle { + channel: Option>>, + } + + impl ServerHandle { + fn get_channel_waiter( + &mut self, + ) -> tokio::sync::oneshot::Receiver> { + let (tx, rx) = tokio::sync::oneshot::channel::>(); + self.channel = Some(tx); + rx + } + } + + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + + async fn channel_open_session( + &mut self, + channel: Channel, + _session: &mut server::Session, + ) -> Result { + if let Some(a) = self.channel.take() { + println!("channel open session {:?}", a); + a.send(channel).unwrap(); + } + Ok(true) + } + } + + let mut sh = ServerHandle { channel: None }; + let scw = sh.get_channel_waiter(); + + test_session( + Client {}, + sh, + |client| async move { + let ch = client.channel_open_session().await.unwrap(); + let mut stream = ch.into_stream(); + stream.write_all(&b"request"[..]).await.unwrap(); + + let mut buf = Vec::new(); + stream.read_buf(&mut buf).await.unwrap(); + assert_eq!(&buf, &b"response"[..]); + + stream.write_all(&b"reply"[..]).await.unwrap(); + + client + }, + |server| async move { + let channel = scw.await.unwrap(); + let mut stream = channel.into_stream(); + + let mut buf = Vec::new(); + stream.read_buf(&mut buf).await.unwrap(); + assert_eq!(&buf, &b"request"[..]); + + stream.write_all(&b"response"[..]).await.unwrap(); + + buf.clear(); + + stream.read_buf(&mut buf).await.unwrap(); + assert_eq!(&buf, &b"reply"[..]); + + server + }, + ) + .await; + } + + #[tokio::test] + async fn test_channel_objects() { + #[derive(Debug)] + struct Client {} + + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(true) + } + } + + struct ServerHandle {} + + impl ServerHandle {} + + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + tokio::spawn(async move { + while let Some(msg) = channel.wait().await { + match msg { + ChannelMsg::Data { data } => { + channel.data(&data[..]).await.unwrap(); + channel.close().await.unwrap(); + break; + } + _ => {} + } + } + }); + Ok(true) + } + } + + let sh = ServerHandle {}; + test_session( + Client {}, + sh, + |c| async move { + let mut ch = c.channel_open_session().await.unwrap(); + ch.data(&b"hello world!"[..]).await.unwrap(); + + let msg = ch.wait().await.unwrap(); + if let ChannelMsg::Data { data } = msg { + assert_eq!(data.as_ref(), &b"hello world!"[..]); + } else { + panic!("Unexpected message {:?}", msg); + } + + assert!(ch.wait().await.is_none()); + c + }, + |s| async move { s }, + ) + .await; + } + + #[tokio::test] + async fn test_channel_window_size() { + #[derive(Debug)] + struct Client {} + + impl client::Handler for Client { + type Error = crate::Error; + + async fn check_server_key( + &mut self, + _server_public_key: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(true) + } + } + + struct ServerHandle { + channel: Option>>, + } + + impl ServerHandle { + fn get_channel_waiter( + &mut self, + ) -> tokio::sync::oneshot::Receiver> { + let (tx, rx) = tokio::sync::oneshot::channel::>(); + self.channel = Some(tx); + rx + } + } + + impl server::Handler for ServerHandle { + type Error = crate::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &crate::keys::ssh_key::PublicKey, + ) -> Result { + Ok(server::Auth::Accept) + } + + async fn channel_open_session( + &mut self, + channel: Channel, + _session: &mut server::Session, + ) -> Result { + if let Some(a) = self.channel.take() { + println!("channel open session {:?}", a); + a.send(channel).unwrap(); + } + Ok(true) + } + } + + let mut sh = ServerHandle { channel: None }; + let scw = sh.get_channel_waiter(); + + test_session( + Client {}, + sh, + |client| async move { + let ch = client.channel_open_session().await.unwrap(); + + let mut writer_1 = ch.make_writer(); + let jh_1 = tokio::spawn(async move { + let buf = [1u8; 1024 * 64]; + assert!(writer_1.write_all(&buf).await.is_ok()); + }); + let mut writer_2 = ch.make_writer(); + let jh_2 = tokio::spawn(async move { + let buf = [2u8; 1024 * 64]; + assert!(writer_2.write_all(&buf).await.is_ok()); + }); + + assert!(tokio::try_join!(jh_1, jh_2).is_ok()); + + client + }, + |server| async move { + let mut channel = scw.await.unwrap(); + + let mut total_data = 2 * 1024 * 64; + while let Some(msg) = channel.wait().await { + match msg { + ChannelMsg::Data { data } => { + total_data -= data.len(); + if total_data == 0 { + break; + } + } + _ => panic!("Unexpected message {:?}", msg), + } + } + + server + }, + ) + .await; + } +} + +mod server_kex_junk { + use std::sync::Arc; + + use tokio::io::AsyncWriteExt; + + use super::server::Server as _; + use super::*; + + #[tokio::test] + async fn server_kex_junk_test() { + let _ = env_logger::try_init(); + + let config = server::Config::default(); + let config = Arc::new(config); + let mut sh = Server {}; + + let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + + tokio::spawn(async move { + let mut client_stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + client_stream + .write_all(b"SSH-2.0-Client_1.0\r\n") + .await + .unwrap(); + // Unexpected message pre-kex + client_stream.write_all(&[0, 0, 0, 2, 0, 99]).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + }); + + let (socket, _) = socket.accept().await.unwrap(); + let server = sh.new_client(socket.peer_addr().ok()); + let rs = server::run_stream(config, socket, server).await.unwrap(); + + // May not panic + assert!(rs.await.is_err()); + } + + #[derive(Clone)] + struct Server {} + + impl server::Server for Server { + type Handler = Self; + fn new_client(&mut self, _: Option) -> Self { + self.clone() + } + } + + impl server::Handler for Server { + type Error = super::Error; + } +} diff --git a/russh/tests/test_backpressure.rs b/russh/tests/test_backpressure.rs new file mode 100644 index 00000000..960d53d0 --- /dev/null +++ b/russh/tests/test_backpressure.rs @@ -0,0 +1,157 @@ +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::Arc; + +use futures::FutureExt; +use rand::RngCore; +use rand_core::OsRng; +use russh::keys::PrivateKeyWithHashAlg; +use russh::server::{self, Auth, Msg, Server as _, Session}; +use russh::{client, Channel, ChannelMsg}; +use ssh_key::PrivateKey; +use tokio::io::AsyncWriteExt; +use tokio::sync::watch; +use tokio::time::sleep; + +pub const WINDOW_SIZE: usize = 8 * 2048; +pub const CHANNEL_BUFFER_SIZE: usize = 10; + +#[tokio::test] +async fn test_backpressure() -> Result<(), anyhow::Error> { + env_logger::init(); + + let addr = addr(); + let data = data(); + let (tx, rx) = watch::channel(()); + + tokio::spawn(Server::run(addr, rx)); + + // Wait until the server is started + while TcpStream::connect(addr).is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + stream(addr, &data, tx).await?; + + Ok(()) +} + +async fn stream(addr: SocketAddr, data: &[u8], tx: watch::Sender<()>) -> Result<(), anyhow::Error> { + let config = Arc::new(client::Config::default()); + let key = Arc::new(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + + let mut session = russh::client::connect(config, addr, Client).await?; + let channel = match session + .authenticate_publickey( + "user", + PrivateKeyWithHashAlg::new( + key, + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) + .await + .map(|x| x.success()) + { + Ok(true) => session.channel_open_session().await?, + Ok(false) => panic!("Authentication failed"), + Err(err) => return Err(err.into()), + }; + + let mut writer = channel.make_writer(); + + // TCP listener will buffer one extra message + for _ in 0..=CHANNEL_BUFFER_SIZE { + assert!(writer.write(data).await.is_ok()); + } + let pending_write = async { writer.write(data).await.unwrap() }; + sleep(std::time::Duration::from_millis(100)).await; + assert_eq!(pending_write.now_or_never(), None); + // Make space on the buffer + tx.send(()).unwrap(); + assert!(writer.write(data).await.is_ok()); + + Ok(()) +} + +fn data() -> Vec { + let mut rng = rand::thread_rng(); + + let mut data = vec![0u8; WINDOW_SIZE]; // Check whether the window_size resizing works + rng.fill_bytes(&mut data); + + data +} + +/// Find a unused local address to bind our server to +fn addr() -> SocketAddr { + TcpListener::bind(("127.0.0.1", 0)) + .unwrap() + .local_addr() + .unwrap() +} + +#[derive(Clone)] +struct Server { + rx: Option>, +} + +impl Server { + async fn run(addr: SocketAddr, rx: watch::Receiver<()>) { + let config = Arc::new(server::Config { + keys: vec![PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()], + window_size: WINDOW_SIZE as u32, + channel_buffer_size: CHANNEL_BUFFER_SIZE, + ..Default::default() + }); + let mut sh = Server { rx: Some(rx) }; + + sh.run_on_address(config, addr).await.unwrap(); + } +} + +impl russh::server::Server for Server { + type Handler = Self; + + fn new_client(&mut self, _: Option) -> Self::Handler { + self.clone() + } +} + +impl russh::server::Handler for Server { + type Error = anyhow::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &ssh_key::PublicKey, + ) -> Result { + Ok(Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + let mut rx = self.rx.take().unwrap(); + tokio::spawn(async move { + while let Ok(_) = rx.changed().await { + match channel.wait().await { + Some(ChannelMsg::Data { .. }) => (), + other => panic!("unexpected message {:?}", other), + } + } + }); + + Ok(true) + } +} + +struct Client; + +impl russh::client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key(&mut self, _: &ssh_key::PublicKey) -> Result { + Ok(true) + } +} diff --git a/russh/tests/test_data_stream.rs b/russh/tests/test_data_stream.rs new file mode 100644 index 00000000..9aec9197 --- /dev/null +++ b/russh/tests/test_data_stream.rs @@ -0,0 +1,226 @@ +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::Arc; + +use rand::RngCore; +use rand_core::OsRng; +use russh::keys::PrivateKeyWithHashAlg; +use russh::server::{self, Auth, Msg, Server as _, Session}; +use russh::{client, Channel, ChannelMsg}; +use ssh_key::PrivateKey; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +pub const WINDOW_SIZE: u32 = 8 * 2048; + +trait ChannelDataCopy { + async fn copy_data_through_channel( + &mut self, + channel: Channel, + data: &[u8], + ) -> anyhow::Result>; +} + +struct ReaderAndWriter; + +impl ChannelDataCopy for ReaderAndWriter { + async fn copy_data_through_channel( + &mut self, + mut channel: Channel, + data: &[u8], + ) -> anyhow::Result> { + let mut buf = Vec::::new(); + let (mut writer, mut reader) = (channel.make_writer_ext(Some(1)), channel.make_reader()); + + let (r0, r1) = tokio::join!( + async { + writer.write_all(data).await?; + writer.shutdown().await?; + + Ok::<_, anyhow::Error>(()) + }, + reader.read_to_end(&mut buf) + ); + + r0?; + let count = r1?; + assert_eq!(data.len(), count); + + Ok(buf) + } +} + +struct ChannelHalves; + +impl ChannelDataCopy for ChannelHalves { + async fn copy_data_through_channel( + &mut self, + channel: Channel, + data: &[u8], + ) -> anyhow::Result> { + let (mut read, write) = channel.split(); + let (r0, r1) = tokio::join!( + async { + write.extended_data(1, data).await?; + write.eof().await?; + + Ok::<_, anyhow::Error>(()) + }, + async { + let mut buf = Vec::::new(); + while let Some(msg) = read.wait().await { + match msg { + ChannelMsg::WindowAdjusted { .. } => {} + ChannelMsg::Data { data } => buf.extend(&*data), + ChannelMsg::Eof => break, + msg => panic!("Got unexpected message: {msg:?}"), + } + } + Ok(buf) + } + ); + + r0?; + r1 + } +} + +#[tokio::test] +async fn test_reader_and_writer() -> Result<(), anyhow::Error> { + run_test(ReaderAndWriter).await +} + +#[tokio::test] +async fn test_channel_halves() -> Result<(), anyhow::Error> { + run_test(ChannelHalves).await +} + +async fn run_test(test: impl ChannelDataCopy) -> Result<(), anyhow::Error> { + static INIT: std::sync::Once = std::sync::Once::new(); + INIT.call_once(env_logger::init); + + let addr = addr(); + let data = data(); + + tokio::spawn(Server::run(addr)); + + // Wait until the server is started + while TcpStream::connect(addr).is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + stream(addr, &data, test).await?; + + Ok(()) +} + +async fn stream( + addr: SocketAddr, + data: &[u8], + mut test: impl ChannelDataCopy, +) -> Result<(), anyhow::Error> { + let config = Arc::new(client::Config::default()); + let key = Arc::new(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + + let mut session = russh::client::connect(config, addr, Client).await?; + let channel = match session + .authenticate_publickey( + "user", + PrivateKeyWithHashAlg::new( + key, + session.best_supported_rsa_hash().await.unwrap().flatten(), + ), + ) + .await + .map(|x| x.success()) + { + Ok(true) => session.channel_open_session().await?, + Ok(false) => panic!("Authentication failed"), + Err(err) => return Err(err.into()), + }; + + let buf = test.copy_data_through_channel(channel, data).await?; + assert_eq!(data, buf); + + Ok(()) +} + +fn data() -> Vec { + let mut rng = rand::thread_rng(); + + let mut data = vec![0u8; WINDOW_SIZE as usize * 2 + 7]; // Check whether the window_size resizing works + rng.fill_bytes(&mut data); + + data +} + +/// Find a unused local address to bind our server to +fn addr() -> SocketAddr { + TcpListener::bind(("127.0.0.1", 0)) + .unwrap() + .local_addr() + .unwrap() +} + +#[derive(Clone)] +struct Server; + +impl Server { + async fn run(addr: SocketAddr) { + let config = Arc::new(server::Config { + keys: vec![PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()], + window_size: WINDOW_SIZE, + ..Default::default() + }); + let mut sh = Server {}; + + sh.run_on_address(config, addr).await.unwrap(); + } +} + +impl russh::server::Server for Server { + type Handler = Self; + + fn new_client(&mut self, _: Option) -> Self::Handler { + self.clone() + } +} + +impl russh::server::Handler for Server { + type Error = anyhow::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &ssh_key::PublicKey, + ) -> Result { + Ok(Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + tokio::spawn(async move { + let (mut writer, mut reader) = + (channel.make_writer(), channel.make_reader_ext(Some(1))); + + tokio::io::copy(&mut reader, &mut writer) + .await + .expect("Data transfer failed"); + + writer.shutdown().await.expect("Shutdown failed"); + }); + + Ok(true) + } +} + +struct Client; + +impl russh::client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key(&mut self, _: &ssh_key::PublicKey) -> Result { + Ok(true) + } +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f8c2abbb..1de01fa4 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "1.65.0" +channel = "1.81.0"