Skip to content

Commit

Permalink
Merge pull request #117 from elmaxxo/elmaxxo/feat/add-ssl-response-me…
Browse files Browse the repository at this point in the history
…ssage

feat: add SslResponse message
  • Loading branch information
sunng87 authored Oct 22, 2023
2 parents aad10d5 + 6e653e8 commit 6ff434c
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 41 deletions.
10 changes: 10 additions & 0 deletions src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ pub enum PgWireBackendMessage {
ReadyForQuery(response::ReadyForQuery),
ErrorResponse(response::ErrorResponse),
NoticeResponse(response::NoticeResponse),
SslResponse(response::SslResponse),

// data
ParameterDescription(data::ParameterDescription),
Expand Down Expand Up @@ -230,6 +231,7 @@ impl PgWireBackendMessage {
Self::ReadyForQuery(msg) => msg.encode(buf),
Self::ErrorResponse(msg) => msg.encode(buf),
Self::NoticeResponse(msg) => msg.encode(buf),
Self::SslResponse(msg) => msg.encode(buf),

Self::ParameterDescription(msg) => msg.encode(buf),
Self::RowDescription(msg) => msg.encode(buf),
Expand Down Expand Up @@ -522,6 +524,14 @@ mod test {
roundtrip!(sslreq, SslRequest);
}

#[test]
fn test_sslresponse() {
let sslaccept = SslResponse::Accept;
roundtrip!(sslaccept, SslResponse);
let sslrefuse = SslResponse::Refuse;
roundtrip!(sslrefuse, SslResponse);
}

#[test]
fn test_saslresponse() {
let saslinitialresp =
Expand Down
57 changes: 57 additions & 0 deletions src/messages/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,60 @@ impl Message for NoticeResponse {
}
}
}

/// Response to SSLRequest.
/// To initiate an SSL-encrypted connection, the frontend initially sends an SSLRequest
/// message rather than a StartupMessage. The server then responds with a single byte
/// containing 'S' or 'N', indicating that it is willing or unwilling to perform SSL, respectively.
#[derive(Debug, PartialEq)]
pub enum SslResponse {
Accept,
Refuse,
}

impl SslResponse {
pub const BYTE_ACCEPT: u8 = b'S';
pub const BYTE_REFUSE: u8 = b'N';
// The whole message takes only one byte and has no size field.
pub const MESSAGE_LENGTH: usize = 1;
}

impl Message for SslResponse {
fn message_length(&self) -> usize {
Self::MESSAGE_LENGTH
}

fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
match self {
Self::Accept => buf.put_u8(Self::BYTE_ACCEPT),
Self::Refuse => buf.put_u8(Self::BYTE_REFUSE),
}
Ok(())
}

fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
self.encode_body(buf)
}

fn decode_body(_: &mut BytesMut, _: usize) -> PgWireResult<Self> {
unreachable!()
}

fn decode(buf: &mut BytesMut) -> PgWireResult<Option<Self>> {
if buf.remaining() >= Self::MESSAGE_LENGTH {
match buf[0] {
Self::BYTE_ACCEPT => {
buf.advance(Self::MESSAGE_LENGTH);
Ok(Some(SslResponse::Accept))
}
Self::BYTE_REFUSE => {
buf.advance(Self::MESSAGE_LENGTH);
Ok(Some(SslResponse::Refuse))
}
_ => Ok(None),
}
} else {
Ok(None)
}
}
}
94 changes: 53 additions & 41 deletions src/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::io::Error as IOError;
use std::sync::Arc;

use bytes::Buf;
use bytes::BytesMut;
use futures::future::poll_fn;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::codec::{Decoder, Encoder, Framed};
Expand All @@ -15,7 +15,7 @@ use crate::api::query::SimpleQueryHandler;
use crate::api::{ClientInfo, ClientInfoHolder, PgWireConnectionState};
use crate::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::messages::response::ReadyForQuery;
use crate::messages::response::READY_STATUS_IDLE;
use crate::messages::response::{SslResponse, READY_STATUS_IDLE};
use crate::messages::startup::{SslRequest, Startup};
use crate::messages::{Message, PgWireBackendMessage, PgWireFrontendMessage};

Expand All @@ -32,11 +32,15 @@ impl Decoder for PgWireMessageServerCodec {
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.client_info.state() {
PgWireConnectionState::AwaitingStartup => {
if let Some(request) = SslRequest::decode(src)? {
return Ok(Some(PgWireFrontendMessage::SslRequest(request)));
}

if let Some(startup) = Startup::decode(src)? {
Ok(Some(PgWireFrontendMessage::Startup(startup)))
} else {
Ok(None)
return Ok(Some(PgWireFrontendMessage::Startup(startup)));
}

Ok(None)
}
_ => PgWireFrontendMessage::decode(src),
}
Expand Down Expand Up @@ -176,45 +180,45 @@ where
Ok(())
}

async fn peek_for_sslrequest(
tcp_socket: &mut TcpStream,
ssl_supported: bool,
) -> Result<bool, IOError> {
let mut ssl = false;
async fn is_sslrequest_pending(tcp_socket: &TcpStream) -> Result<bool, IOError> {
let mut buf = [0u8; SslRequest::BODY_SIZE];
let mut buf = ReadBuf::new(&mut buf);
loop {
let size = poll_fn(|cx| tcp_socket.poll_peek(cx, &mut buf)).await?;
if size == 0 {
while buf.filled().len() < SslRequest::BODY_SIZE {
if poll_fn(|cx| tcp_socket.poll_peek(cx, &mut buf)).await? == 0 {
// the tcp_stream has ended
return Ok(false);
}
if size == SslRequest::BODY_SIZE {
let mut buf_ref = buf.filled();
// skip first 4 bytes
buf_ref.get_i32();
if buf_ref.get_i32() == SslRequest::BODY_MAGIC_NUMBER {
// the socket is sending sslrequest, read the first 8 bytes
// skip first 8 bytes
tcp_socket
.read_exact(&mut [0u8; SslRequest::BODY_SIZE])
.await?;
// ssl configured
if ssl_supported {
ssl = true;
tcp_socket.write_all(b"S").await?;
} else {
tcp_socket.write_all(b"N").await?;
}
}
}

return Ok(ssl);
}
let mut buf = BytesMut::from(buf.filled());
if let Ok(Some(_)) = SslRequest::decode(&mut buf) {
return Ok(true);
}
Ok(false)
}

async fn peek_for_sslrequest(
socket: &mut Framed<TcpStream, PgWireMessageServerCodec>,
ssl_supported: bool,
) -> Result<bool, IOError> {
let mut ssl = false;
if is_sslrequest_pending(socket.get_ref()).await? {
// consume request
socket.next().await;

let response = if ssl_supported {
ssl = true;
PgWireBackendMessage::SslResponse(SslResponse::Accept)
} else {
PgWireBackendMessage::SslResponse(SslResponse::Refuse)
};
socket.send(response).await?;
}
Ok(ssl)
}

pub async fn process_socket<A, Q, EQ>(
mut tcp_socket: TcpStream,
tcp_socket: TcpStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
startup_handler: Arc<A>,
query_handler: Arc<Q>,
Expand All @@ -227,13 +231,14 @@ where
{
let addr = tcp_socket.peer_addr()?;
tcp_socket.set_nodelay(true)?;

let client_info = ClientInfoHolder::new(addr, false);
let mut tcp_socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info));
let ssl = peek_for_sslrequest(&mut tcp_socket, tls_acceptor.is_some()).await?;

let client_info = ClientInfoHolder::new(addr, ssl);
if ssl {
// safe to unwrap tls_acceptor here
let ssl_socket = tls_acceptor.unwrap().accept(tcp_socket).await?;
let mut socket = Framed::new(ssl_socket, PgWireMessageServerCodec::new(client_info));
if !ssl {
// use an already configured socket.
let mut socket = tcp_socket;

while let Some(Ok(msg)) = socket.next().await {
if let Err(e) = process_message(
Expand All @@ -249,7 +254,14 @@ where
}
}
} else {
let mut socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info));
// mention the use of ssl
let client_info = ClientInfoHolder::new(addr, true);
// safe to unwrap tls_acceptor here
let ssl_socket = tls_acceptor
.unwrap()
.accept(tcp_socket.into_inner())
.await?;
let mut socket = Framed::new(ssl_socket, PgWireMessageServerCodec::new(client_info));

while let Some(Ok(msg)) = socket.next().await {
if let Err(e) = process_message(
Expand Down

0 comments on commit 6ff434c

Please sign in to comment.