Skip to content

Commit

Permalink
Merge pull request #393 from str4d/368-async-buffered-header-reading
Browse files Browse the repository at this point in the history
Enable header parsing to use `R: futures::io::AsyncBufRead`
  • Loading branch information
str4d authored Jun 12, 2023
2 parents 623f663 + 37012ba commit 22882cb
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 86 deletions.
6 changes: 4 additions & 2 deletions age/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ to 1.0.0 are beta releases.

## [Unreleased]
### Added
- `age::Decryptor::new_buffered`, which is more efficient for types implementing
`std::io::BufRead` (which includes `&[u8]` slices).
- `age::Decryptor::{new_buffered, new_async_buffered}`, which are more efficient
for types implementing `std::io::BufRead` or `futures::io::AsyncBufRead`
(which includes `&[u8]` slices).
- `impl std::io::BufRead for age::armor::ArmoredReader`
- `impl futures::io::AsyncBufRead for age::armor::ArmoredReader`

## [0.9.1] - 2022-03-24
### Added
Expand Down
36 changes: 35 additions & 1 deletion age/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use crate::{
};

#[cfg(feature = "async")]
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
};

const AGE_MAGIC: &[u8] = b"age-encryption.org/";
const V1_MAGIC: &[u8] = b"v1";
Expand Down Expand Up @@ -146,6 +148,38 @@ impl Header {
}
}

/// Parses an age header from a buffered asynchronous reader.
///
/// Input is pulled from `input` one newline-terminated chunk at a time and
/// accumulated; after every chunk the header parser is re-run over everything
/// gathered so far, until it either succeeds or rejects the data.
///
/// # Errors
///
/// - `DecryptError::Io` with `UnexpectedEof` if the reader is exhausted while
///   the parser still wants more data, or if the underlying read fails.
/// - `DecryptError::InvalidHeader` for any other parser outcome.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub(crate) async fn read_async_buffered<R: AsyncBufRead + Unpin>(
    mut input: R,
) -> Result<Self, DecryptError> {
    let mut collected = Vec::new();
    loop {
        match read::header(&collected) {
            Err(nom::Err::Incomplete(nom::Needed::Size(_))) => {
                // The currently-defined header formats are newline-separated, so
                // with a buffered reader we can hand the parser a whole line at a
                // time rather than trickling bytes in.
                let appended = input.read_until(b'\n', &mut collected).await?;
                if appended == 0 {
                    return Err(DecryptError::Io(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "Incomplete header",
                    )));
                }
            }
            Err(_) => return Err(DecryptError::InvalidHeader),
            Ok((_, mut parsed)) => {
                // Keep the exact bytes that were parsed alongside a v1 header.
                if let Header::V1(v1) = &mut parsed {
                    v1.encoded_bytes = Some(collected);
                }
                return Ok(parsed);
            }
        }
    }
}

pub(crate) fn write<W: Write>(&self, mut output: W) -> io::Result<()> {
cookie_factory::gen(write::header(self), &mut output)
.map(|_| ())
Expand Down
204 changes: 122 additions & 82 deletions age/src/primitives/armor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use futures::{
#[cfg(feature = "async")]
use std::mem;
#[cfg(feature = "async")]
use std::ops::DerefMut;
#[cfg(feature = "async")]
use std::pin::Pin;
#[cfg(feature = "async")]
use std::str;
Expand Down Expand Up @@ -722,28 +724,6 @@ impl<R> ArmoredReader<R> {
Ok(())
}

/// Copies pending bytes from `self.byte_buf` into `buf`.
///
/// The pending region is `self.byte_buf[self.byte_start..self.byte_end]`. As many
/// bytes as fit in `buf` are copied, and `self.byte_start` and `self.data_read`
/// are both advanced by that amount.
///
/// Returns the number of bytes copied, or `None` if the pending region is empty.
#[cfg(feature = "async")]
fn read_cached_data(&mut self, buf: &mut [u8]) -> Option<usize> {
    if self.byte_start >= self.byte_end {
        return None;
    }

    let pending = &self.byte_buf[self.byte_start..self.byte_end];
    let count = pending.len().min(buf.len());
    buf[..count].copy_from_slice(&pending[..count]);

    self.byte_start += count;
    self.data_read += count;
    Some(count)
}

/// Validates `self.line_buf` and parses it into `self.byte_buf`.
///
/// Returns `true` if this was the last line.
Expand Down Expand Up @@ -985,12 +965,8 @@ fn read_line_internal<R: AsyncBufRead + ?Sized>(

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncBufRead + Unpin> AsyncRead for ArmoredReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<Result<usize, Error>> {
impl<R: AsyncBufRead + Unpin> AsyncBufRead for ArmoredReader<R> {
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
loop {
match self.is_armored {
None => {
Expand All @@ -1006,71 +982,135 @@ impl<R: AsyncBufRead + Unpin> AsyncRead for ArmoredReader<R> {
self.detect_armor()?
}
Some(false) => {
// Return any leftover data from armor detection.
return if let Some(read) = self.read_cached_data(buf) {
Poll::Ready(Ok(read))
let ret = if self.byte_start >= self.byte_end {
let this = self.as_mut().project();
let read = ready!(this.inner.poll_read(cx, &mut this.byte_buf[..]))?;
(*this.byte_start) = 0;
(*this.byte_end) = read;
self.count_reader_bytes(read);
&self.get_mut().byte_buf[..read]
} else {
self.as_mut().project().inner.poll_read(cx, buf).map(|res| {
res.map(|read| {
self.data_read += read;
self.count_reader_bytes(read)
})
})
let this = self.get_mut();
&this.byte_buf[this.byte_start..this.byte_end]
};
break Poll::Ready(Ok(ret));
}
Some(true) if self.found_end => return Poll::Ready(Ok(0)),
Some(true) if self.found_end => return Poll::Ready(Ok(&[])),
Some(true) => {
// Output any remaining bytes from the previous line
if let Some(read) = self.read_cached_data(buf) {
return Poll::Ready(Ok(read));
}

// Read the next line
{
// Emulates `AsyncBufReadExt::read_line`.
let mut this = self.as_mut().project();
let buf: &mut String = this.line_buf;
let mut bytes = mem::take(buf).into_bytes();
let mut read = 0;
ready!(read_line_internal(
this.inner.as_mut(),
cx,
buf,
&mut bytes,
&mut read,
))
}
.map(|read| self.count_reader_bytes(read))?;

// Parse the line into bytes.
let read = if self.parse_armor_line()? {
// This was the last line! Check for trailing garbage.
let mut this = self.as_mut().project();
loop {
let amt = match ready!(this.inner.as_mut().poll_fill_buf(cx))? {
&[] => break,
buf => {
if buf.iter().any(|b| !b.is_ascii_whitespace()) {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
ArmoredReadError::TrailingGarbage,
)));
}
buf.len()
}
};
this.inner.as_mut().consume(amt);
let ret = if self.byte_start >= self.byte_end {
let last =
ready!(Pin::new(self.deref_mut()).poll_read_next_armor_line(cx))?;
if last {
&[]
} else {
let this = self.get_mut();
&this.byte_buf[this.byte_start..this.byte_end]
}
0
} else {
// Output as much as we can of this line.
self.read_cached_data(buf).unwrap_or(0)
let this = self.get_mut();
&this.byte_buf[this.byte_start..this.byte_end]
};
break Poll::Ready(Ok(ret));
}
}
}
}

/// Marks `amt` bytes of the slice handed out by `poll_fill_buf` as consumed.
///
/// Advances `byte_start` (the start of the unread region of `byte_buf`) and the
/// `data_read` counter by `amt`.
///
/// # Panics
///
/// Panics if, after advancing, `byte_start` is greater than `byte_end`, i.e. if
/// the caller consumed more bytes than were made available.
fn consume(mut self: Pin<&mut Self>, amt: usize) {
let this = self.as_mut().project();
(*this.byte_start) += amt;
(*this.data_read) += amt;
// The check runs after the counters have already been updated.
assert!(this.byte_start <= this.byte_end);
}
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncBufRead + Unpin> ArmoredReader<R> {
/// Fills `self.byte_buf` with the next line of armored data.
///
/// Reads one line from the inner reader into `self.line_buf`, then hands it to
/// `parse_armor_line`. If that reports the final line, the rest of the inner
/// reader is drained and checked for trailing garbage.
///
/// Returns `true` if this was the last line, `false` otherwise. Yields
/// `Poll::Pending` whenever the inner reader is not ready.
///
/// # Errors
///
/// Propagates errors from reading or parsing the line. Returns an
/// `io::ErrorKind::InvalidData` error wrapping `ArmoredReadError::TrailingGarbage`
/// if any byte after the last line is not ASCII whitespace.
///
/// # Panics
///
/// Panics if `self.is_armored` is not `Some(true)`.
fn poll_read_next_armor_line(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<bool>> {
assert_eq!(self.is_armored, Some(true));

// Read the next line
{
// Emulates `AsyncBufReadExt::read_line`.
let mut this = self.as_mut().project();
let buf: &mut String = this.line_buf;
// `mem::take` leaves `line_buf` empty and moves its contents into `bytes`.
// NOTE(review): `bytes` and `read` are locals, so if `read_line_internal`
// returns `Pending` after appending part of a line to `bytes`, that partial
// data appears to be dropped here — confirm against `read_line_internal`.
let mut bytes = mem::take(buf).into_bytes();
let mut read = 0;
ready!(read_line_internal(
this.inner.as_mut(),
cx,
buf,
&mut bytes,
&mut read,
))
}
// Report the number of bytes read for this line to `count_reader_bytes`.
.map(|read| self.count_reader_bytes(read))?;

// Parse the line into bytes.
if self.parse_armor_line()? {
// This was the last line! Check for trailing garbage.
let mut this = self.as_mut().project();
loop {
// Drain the inner reader chunk by chunk until it reports EOF (empty slice).
let amt = match ready!(this.inner.as_mut().poll_fill_buf(cx))? {
&[] => break,
buf => {
// Only ASCII whitespace is tolerated after the final armor line.
if buf.iter().any(|b| !b.is_ascii_whitespace()) {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
ArmoredReadError::TrailingGarbage,
)));
}
buf.len()
}
};
this.inner.as_mut().consume(amt);
}
Poll::Ready(Ok(true))
} else {
Poll::Ready(Ok(false))
}
}
}

return Poll::Ready(Ok(read));
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncBufRead + Unpin> AsyncRead for ArmoredReader<R> {
    /// Reads into `buf` by repeatedly pulling from this reader's own
    /// `poll_fill_buf` / `consume` implementation.
    ///
    /// Keeps copying until `buf` is full, the buffered source reports
    /// end-of-data (an empty slice), or the source would block. If the source
    /// would block after some bytes were already copied, those bytes are
    /// reported instead of `Poll::Pending`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        mut buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        let requested = buf.len();

        while !buf.is_empty() {
            let available = match Pin::new(self.deref_mut()).poll_fill_buf(cx) {
                Poll::Ready(Ok(chunk)) if !chunk.is_empty() => chunk,
                // An empty slice signals that no more data will arrive.
                Poll::Ready(Ok(_)) => break,
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                // Partial progress was made: hand it back rather than blocking.
                Poll::Pending if buf.len() < requested => break,
                Poll::Pending => return Poll::Pending,
            };

            let n = cmp::min(available.len(), buf.len());
            buf[..n].copy_from_slice(&available[..n]);

            Pin::new(self.deref_mut()).consume(n);
            buf = &mut buf[n..];
        }

        Poll::Ready(Ok(requested - buf.len()))
    }
}

Expand Down
33 changes: 32 additions & 1 deletion age/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
};

#[cfg(feature = "async")]
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

pub mod decryptor;

Expand Down Expand Up @@ -229,6 +229,13 @@ impl<R: AsyncRead + Unpin> Decryptor<R> {
/// Attempts to create a decryptor for an age file.
///
/// Returns an error if the input does not contain a valid age file.
///
/// # Performance
///
/// This constructor will work with any type implementing [`AsyncRead`], and uses a
/// slower parser and internal buffering to ensure no overreading occurs. Consider
/// using [`Decryptor::new_async_buffered`] for types implementing [`AsyncBufRead`],
/// which includes `&[u8]` slices.
pub async fn new_async(mut input: R) -> Result<Self, DecryptError> {
let header = Header::read_async(&mut input).await?;

Expand All @@ -242,6 +249,30 @@ impl<R: AsyncRead + Unpin> Decryptor<R> {
}
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncBufRead + Unpin> Decryptor<R> {
    /// Attempts to create a decryptor for an age file read from a buffered
    /// asynchronous source.
    ///
    /// The header is parsed first; for a v1 header the nonce is then read from
    /// the same input before the decryptor is built.
    ///
    /// # Errors
    ///
    /// Returns an error if the input does not contain a valid age file, and
    /// [`DecryptError::UnknownFormat`] if the header is of an unrecognised
    /// version.
    ///
    /// # Performance
    ///
    /// This constructor is more performant than [`Decryptor::new_async`] for types
    /// implementing [`AsyncBufRead`], which includes `&[u8]` slices.
    pub async fn new_async_buffered(mut input: R) -> Result<Self, DecryptError> {
        let parsed = Header::read_async_buffered(&mut input).await?;

        match parsed {
            Header::Unknown(_) => Err(DecryptError::UnknownFormat),
            Header::V1(v1_header) => {
                let nonce = Nonce::read_async(&mut input).await?;
                Decryptor::from_v1_header(input, v1_header, nonce)
            }
        }
    }
}

#[cfg(test)]
mod tests {
use age_core::secrecy::SecretString;
Expand Down
Loading

0 comments on commit 22882cb

Please sign in to comment.