From ab8d48e400f4cd77f96a6f8dc652cfe6003fdb9c Mon Sep 17 00:00:00 2001 From: Aeledfyr Date: Tue, 29 Apr 2025 22:43:01 -0500 Subject: [PATCH 1/2] Implement `string.format` - Almost entirely compatible with PRLua under the C locale - Minor differences in the handling of '%q'; - Invalid UTF8 is escaped rather than passed through - math.mininteger is represented as `(-9223372036854775807-1)` rather than `0x8000000000000000` due to parsing differences in Piccolo. - Supports a few extra specifiers: - %b and %B for binary integers (C23) - %F for uppercase floats - %C for encoding a unicode codepoint as utf8 - %S for utf8-aware strings --- src/stdlib/string.rs | 13 + src/stdlib/string/format.rs | 829 +++++++++++++++++++++++++ src/stdlib/string/format/float.rs | 303 +++++++++ src/stdlib/string/format/parse.rs | 94 +++ tests/scripts-wishlist/num-formats.lua | 28 + tests/scripts/format.lua | 313 ++++++++++ 6 files changed, 1580 insertions(+) create mode 100644 src/stdlib/string/format.rs create mode 100644 src/stdlib/string/format/float.rs create mode 100644 src/stdlib/string/format/parse.rs create mode 100644 tests/scripts-wishlist/num-formats.lua create mode 100644 tests/scripts/format.lua diff --git a/src/stdlib/string.rs b/src/stdlib/string.rs index 04deba71..603b99b4 100644 --- a/src/stdlib/string.rs +++ b/src/stdlib/string.rs @@ -1,5 +1,7 @@ use crate::{Callback, CallbackReturn, Context, String, Table}; +mod format; + pub fn load_string<'gc>(ctx: Context<'gc>) { let string = Table::new(&ctx); @@ -99,5 +101,16 @@ pub fn load_string<'gc>(ctx: Context<'gc>) { }), ); + string + .set( + ctx, + "format", + Callback::from_fn(&ctx, |ctx, _, stack| { + let seq = format::string_format(ctx, stack)?; + Ok(CallbackReturn::Sequence(crate::BoxSequence::new(&ctx, seq))) + }), + ) + .unwrap(); + ctx.set_global("string", string); } diff --git a/src/stdlib/string/format.rs b/src/stdlib/string/format.rs new file mode 100644 index 00000000..c2bf06c4 --- /dev/null +++ 
b/src/stdlib/string/format.rs @@ -0,0 +1,829 @@ +//! An implementation of C's `sprintf` / Lua's `string.format` +//! +//! References: +//! - [Lua 5.4 Manual on `string.format`](https://www.lua.org/manual/5.4/manual.html#pdf-string.format) +//! - [glibc manual 12.12: Formatted Output](https://www.gnu.org/software/libc/manual/html_node/Formatted-Output.html) +//! - [Documentation of specific meaning of `%g`](https://stackoverflow.com/a/54162153) +//! - [Python's formatting docs](https://docs.python.org/3/library/string.html#format-specification-mini-language) +//! +//! Specifier syntax: `"%" [flags] [width] ["." precision] spec` +//! +//! Supported specifiers: +//! - `%%` - a literal `%` character +//! - `%c` - output a raw byte (modulo 256 for integers larger than 255) +//! - `%C` - format a unicode code point as utf8 +//! - `%s` - string (prints raw bytes, does not reinterpret as utf8) +//! - `%S` - utf8 string; width and precision are in terms of codepoints +//! - `%d`, `%i` - signed integer +//! - `%u` - unsigned integer (converted to 64 bit signed integer, then interpreted as unsigned) +//! - `%o` - unsigned octal integer +//! - `%x`, `%X` - unsigned hexadecimal integer +//! - `%b`, `%B` - unsigned binary integer +//! - `%f`, `%F` - normal form floating point +//! - `%g`, `%G` - compact floating point +//! - `%e`, `%E` - exponential form floating point +//! - `%a`, `%A` - hexadecimal floating point +//! - `%p` - format a value as a pointer, for non-literal values +//! - `%q` - format a value as an escaped Lua literal; supports +//! `nil`, `bool`, `string`, `integer`, or `float` (formatted as a hex float) +//! +//! Supported flags: +//! - `-`: left align +//! - `0`: zero pad +//! - ` ` (space): include space in sign position for positive numbers +//! - `+`: include sign for positive numbers, overriding space if both are specified +//! - `#`: alternate mode +//! - On floats, preserve a trailing decimal point +//! 
- On hex/octal/binary integers, prefix with the format +//! (`0x`, `0`, and `0b`, respectively) +//! +//! Width and precision are supported, but are limited to `ARG_MAX` (99). +//! - Width: specify the minimum width to pad to +//! - Rejected on `%q`, which accepts no flags, width, or precision +//! - Precision: +//! - For `%s`, truncates the string to the specified length +//! - For integer specs, zero-pads the number to the specified length +//! (may differ from `width`, which is still padded with spaces) +//! - For floats, specifies the number of digits of precision to use +//! - This implementation supports using `*` to read width/precision +//! from the argument list. The argument is converted to an integer; +//! for `width`, if the argument is negative, the value will be left +//! aligned, and use `abs(arg)` as the width. +//! +//! Compatibility notes: +//! - This should match output of PRLua's `string.format` / POSIX sprintf +//! in the vast majority of cases, but there will be differences. +//! - This implementation is not locale-aware, and assumes LC_ALL=C. +//! - Floating point formatting may differ slightly: +//! - `%f` specifier does not support `#` to require a trailing decimal +//! point, due to implementation limitations +//! - formatting of subnormal numbers has not been thoroughly tested, +//! may have rounding errors. +//! - PRLua does not support `%F` (uppercase float; only differs for inf/nan) +//! - PRLua does not support the C23 `%b`/`%B` (binary unsigned int) specifiers +//! - PRLua's `%q` represents `math.mininteger` as `0x8000000000000000`, but +//! piccolo represents it as `(-9223372036854775807-1)` +//! - PRLua's `%q` passes any byte above 127 through as a raw byte; this +//! implementation passes through valid UTF-8 codepoints, but escapes +//! other bytes. +//! - (Matching PRLua) No support for C style value length specifiers. +//! (such as `%lld` for `u64`s) +//! - (Matching PRLua) No support for `%n` (length write-back) +//! - Supports `%C` to format a unicode code point as UTF-8. 
+//! (in C, this is either `%C` or `%lc`, if the locale supports it) +//! - Supports `%S`, a unicode-aware variant of `%s`; width and precision +//! are specified in terms of codepoints rather than bytes. + +use core::{char, pin::Pin}; +use std::io::Write; + +use gc_arena::{Collect, Gc}; +use thiserror::Error; + +use crate::meta_ops; +use crate::{Context, Error, Execution, FromValue, Function, Sequence, SequencePoll, Stack, Value}; + +mod float; +mod parse; + +use float::FloatMode; + +const FMT_SPEC: u8 = b'%'; +const ARG_MAX: u32 = 99; + +#[derive(Debug, Error)] +enum FormatError { + #[error("invalid format specifier {:?}", *.0 as char)] + BadSpec(u8), + #[error("invalid format specifier; precision is limited to {}", ARG_MAX)] + BadPrecision, + #[error("invalid format specifier; width is limited to {}", ARG_MAX)] + BadWidth, + #[error("invalid format specifier; flag {:?} is not supported for {}", *.1, *.0 as char)] + BadFlag(u8, Flags), + #[error("missing value for format specifier {:?}", *.0 as char)] + MissingValue(u8), + #[error("value of wrong type for format specifier {:?}; expected {}, found {}", *.0 as char, .1, .2)] + BadValueType(u8, &'static str, &'static str), + #[error("value out of range for format specifier {:?}", *.0 as char)] + ValueOutOfRange(u8), + #[error("weird floating point value?")] + BadFloat, + #[error("Non-utf8 string passed to specifier {:?}", *.0 as char)] + NonUnicodeString(u8), +} + +#[derive(Default, Copy, Clone)] +pub struct Flags(u8); + +impl Flags { + const ALTERNATE: Self = Self(1 << 0); + const LEFT_ALIGN: Self = Self(1 << 1); + const ZERO_PAD: Self = Self(1 << 2); + const SIGN_FORCE: Self = Self(1 << 3); + const SIGN_SPACE: Self = Self(1 << 4); + const WIDTH: Self = Self(1 << 5); + const PRECISION: Self = Self(1 << 6); + + const NONE: Self = Self(0); + const ALL: Self = Self(0b01111111); + + const UINT_ALLOWED: Self = + Self(Self::LEFT_ALIGN.0 | Self::ZERO_PAD.0 | Self::WIDTH.0 | Self::PRECISION.0); + const SINT_ALLOWED: Self = 
Self(Self::UINT_ALLOWED.0 | Self::SIGN_FORCE.0 | Self::SIGN_SPACE.0); +} + +impl Flags { + fn has(self, flag: Flags) -> bool { + self.0 & flag.0 == flag.0 + } +} + +impl core::ops::BitOr for Flags { + type Output = Self; + fn bitor(self, rhs: Self) -> Self { + Self(self.0 | rhs.0) + } +} + +impl core::ops::BitOrAssign for Flags { + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } +} + +impl core::fmt::Debug for Flags { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut first = true; + write!(f, "Flags(")?; + for (flag, name) in [ + (Self::ALTERNATE, "ALTERNATE"), + (Self::LEFT_ALIGN, "LEFT_ALIGN"), + (Self::ZERO_PAD, "ZERO_PAD"), + (Self::SIGN_FORCE, "SIGN_FORCE"), + (Self::SIGN_SPACE, "SIGN_SPACE"), + (Self::WIDTH, "WIDTH"), + (Self::PRECISION, "PRECISION"), + ] { + if self.has(flag) { + if first { + first = false; + write!(f, "{}", name)?; + } else { + write!(f, " | {}", name)?; + } + } + } + write!(f, ")")?; + Ok(()) + } +} + +// Note: if width is specified by a argument, it will be interpreted +// as its absolute value, setting the left align flag if negative. 
+#[derive(Copy, Clone)] +struct FormatSpecifier { + spec: u8, + flags: Flags, + width: OptionalArg, + precision: OptionalArg, +} + +#[derive(Copy, Clone)] +enum OptionalArg { + None, + Arg, + Specified(u32), +} + +#[derive(Default, Clone, Copy)] +struct FormatArgs { + width: usize, + precision: Option, + left_align: bool, + zero_pad: bool, + alternate: bool, + upper: bool, + flags: Flags, +} + +impl FormatSpecifier { + fn check_flags(&self, allowed: Flags) -> Result<(), FormatError> { + let leftover = self.flags.0 & !allowed.0; + if leftover != 0 { + Err(FormatError::BadFlag(self.spec, Flags(leftover))) + } else { + Ok(()) + } + } + + fn get_arg<'gc>( + &self, + arg: OptionalArg, + values: &mut impl Iterator>, + ) -> Result<(Option, bool), FormatError> { + match arg { + OptionalArg::None => Ok((None, false)), + OptionalArg::Arg => { + let int = self.next_int(values)?; + let negative = int < 0; + let abs = int.unsigned_abs(); + if abs > ARG_MAX as u64 { + return Err(FormatError::ValueOutOfRange(self.spec)); + } + Ok((Some(abs as usize), negative)) + } + OptionalArg::Specified(val) => Ok((Some(val as usize), false)), + } + } + + fn common_args<'gc>( + &self, + values: &mut impl Iterator>, + ) -> Result { + let (width, width_neg) = self.get_arg(self.width, values)?; + let (precision, _) = self.get_arg(self.precision, values)?; + Ok(FormatArgs { + width: width.unwrap_or(0), + precision, + left_align: self.flags.has(Flags::LEFT_ALIGN) || width_neg, + zero_pad: self.flags.has(Flags::ZERO_PAD) + && !(self.flags.has(Flags::LEFT_ALIGN) || width_neg), + alternate: self.flags.has(Flags::ALTERNATE), + upper: self.spec.is_ascii_uppercase(), + flags: self.flags, + }) + } + + fn next_value<'gc>( + &self, + values: &mut impl Iterator>, + ) -> Result, FormatError> { + values.next().ok_or(FormatError::MissingValue(self.spec)) + } + + fn next_int<'gc>( + &self, + values: &mut impl Iterator>, + ) -> Result { + let val = self.next_value(values)?; + let int = val + .to_integer() + 
.ok_or_else(|| FormatError::BadValueType(self.spec, "integer", val.type_name()))?; + Ok(int) + } + + fn next_float<'gc>( + &self, + values: &mut impl Iterator>, + ) -> Result { + let val = self.next_value(values)?; + let float = val + .to_number() + .ok_or_else(|| FormatError::BadValueType(self.spec, "number", val.type_name()))?; + Ok(float) + } +} + +impl FormatArgs { + fn sign_char(&self, negative: bool) -> &'static [u8] { + if negative { + b"-" + } else if self.flags.has(Flags::SIGN_FORCE) { + b"+" + } else if self.flags.has(Flags::SIGN_SPACE) { + b" " + } else { + b"" + } + } + + /// Returns the width to which an integer should be zero-padded to, + /// if zero-padding is requested. + fn integer_zeroed_width(&self, prefix: &[u8]) -> usize { + if let Some(p) = self.precision { + p + } else if self.zero_pad { + self.width.saturating_sub(prefix.len()) + } else { + 0 + } + } + + /// Write the initial space padding, prefix (sign) and zero padding + /// for a format specifier. This returns a [`PadScope`], which must + /// be used to finish the trailing padding by calling [`PadScope::finish_pad`]. 
+ fn pad_num_before( + &self, + w: &mut W, + len: usize, + zeroed_width: usize, + prefix: &[u8], + ) -> Result { + // right: [ ][-][0000][nnnn] + // left: [-][0000][nnnn][ ] + let zero_padding = zeroed_width.saturating_sub(len); + let space_padding = self.width.saturating_sub(zero_padding + prefix.len() + len); + if space_padding > 0 && !self.left_align { + write_padding(w, b' ', space_padding)?; + } + if !prefix.is_empty() { + w.write_all(prefix)?; + } + if zero_padding > 0 { + write_padding(w, b'0', zero_padding)?; + } + let trailing_padding = if self.left_align { space_padding } else { 0 }; + Ok(PadScope { trailing_padding }) + } +} + +#[must_use] +struct PadScope { + trailing_padding: usize, +} + +impl PadScope { + fn finish_pad(self, w: &mut W) -> Result<(), std::io::Error> { + if self.trailing_padding > 0 { + write_padding(w, b' ', self.trailing_padding)?; + } + Ok(()) + } +} + +/// Write the given byte repeated `count` times. +fn write_padding(w: &mut W, byte: u8, count: usize) -> Result<(), std::io::Error> { + let buf = [byte; 16]; + let mut remaining = count; + while remaining > 0 { + match w.write(&buf[..remaining.min(buf.len())]) { + Ok(n) => remaining -= n, + Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue, + Err(e) => return Err(e), + } + } + Ok(()) +} + +fn integer_length(i: u64) -> usize { + 1 + i.checked_ilog10().unwrap_or(0) as usize +} +fn integer_length_hex(i: u64) -> usize { + 1 + i.checked_ilog2().unwrap_or(0) as usize / 4 +} +fn integer_length_octal(i: u64) -> usize { + 1 + i.checked_ilog2().unwrap_or(0) as usize / 3 +} +fn integer_length_binary(i: u64) -> usize { + 1 + i.checked_ilog2().unwrap_or(0) as usize +} + +fn memchr(needle: u8, haystack: &[u8]) -> Option { + haystack.iter().position(|&b| b == needle) +} + +pub fn string_format<'gc>( + ctx: Context<'gc>, + stack: Stack<'gc, '_>, +) -> Result, Error<'gc>> { + let str = crate::string::String::from_value(ctx, stack.get(0))?; + Ok(FormatState { + buf: Vec::new(), + 
arg_count: stack.len(), + str, + index: 0, + value_index: 1, + inner: FormatStateInner::Start, + }) +} + +#[derive(Collect)] +#[collect(no_drop)] +struct FormatState<'gc> { + buf: Vec, + arg_count: usize, + str: crate::string::String<'gc>, + index: usize, + value_index: usize, + #[collect(require_static)] + inner: FormatStateInner, +} + +enum FormatStateInner { + Start, + EvaluateCallback { + spec: FormatSpecifier, + dest: EvalContinuation, + }, + End, +} + +enum EvalPoll<'gc> { + Done, + PassValue { + value: Value<'gc>, + then: EvalContinuation, + }, + Call { + call: meta_ops::MetaCall<'gc, 1>, + then: EvalContinuation, + }, +} + +#[derive(Copy, Clone)] +enum EvalContinuation { + Init, + ToStringResult(FormatArgs), + UnicodeToStringResult(FormatArgs), +} + +impl<'gc> Sequence<'gc> for FormatState<'gc> { + fn poll( + self: Pin<&mut Self>, + ctx: Context<'gc>, + _exec: Execution<'gc, '_>, + stack: Stack<'gc, '_>, + ) -> Result, Error<'gc>> { + step(ctx, self.get_mut(), stack) + } +} + +impl<'gc> EvalPoll<'gc> { + fn from_metaresult(res: meta_ops::MetaResult<'gc, 1>, then: EvalContinuation) -> Self { + match res { + meta_ops::MetaResult::Value(value) => EvalPoll::PassValue { value, then }, + meta_ops::MetaResult::Call(call) => EvalPoll::Call { call, then }, + } + } +} + +fn step<'gc>( + ctx: Context<'gc>, + state: &mut FormatState<'gc>, + mut stack: Stack<'gc, '_>, +) -> Result, Error<'gc>> { + let mut float_buf = [0u8; 300]; + + loop { + match state.inner { + FormatStateInner::Start => { + if let Some(next) = + memchr(FMT_SPEC, &state.str[state.index..]).map(|n| n + state.index) + { + if next != state.index { + state.buf.write_all(&state.str[state.index..next])?; + } + + let (spec, spec_end) = parse::parse_specifier(state.str.as_bytes(), next)?; + state.index = spec_end; + assert!(state.index > next); + + state.inner = FormatStateInner::EvaluateCallback { + spec, + dest: EvalContinuation::Init, + }; + } else { + if state.index < state.str.as_bytes().len() { + 
state.buf.write_all(&state.str[state.index..])?; + } + state.inner = FormatStateInner::End; + } + } + FormatStateInner::EvaluateCallback { spec, dest } => { + let result = stack.get(state.arg_count); + stack.resize(state.arg_count); + + let remaining_args = state.arg_count - state.value_index; + let mut values_iter = stack[state.value_index..state.arg_count].iter(); + let poll = evaluate_continuation( + ctx, + &mut state.buf, + dest, + spec, + Some(result), + &mut (&mut values_iter).copied(), + &mut float_buf, + )?; + state.value_index += remaining_args - values_iter.as_slice().len(); + + match poll { + EvalPoll::PassValue { value, then } => { + state.inner = FormatStateInner::EvaluateCallback { spec, dest: then }; + stack.push_back(value); + continue; + } + EvalPoll::Call { call, then } => { + state.inner = FormatStateInner::EvaluateCallback { spec, dest: then }; + let bottom = stack.len(); + stack.extend(call.args); + return Ok(SequencePoll::Call { + function: call.function, + bottom, + }); + } + EvalPoll::Done => { + state.inner = FormatStateInner::Start; + } + } + } + FormatStateInner::End => { + stack.replace(ctx, ctx.intern(&state.buf)); + return Ok(SequencePoll::Return); + } + }; + } +} + +fn evaluate_continuation<'gc, W: Write>( + ctx: Context<'gc>, + w: &mut W, + cont: EvalContinuation, + spec: FormatSpecifier, + result: Option>, + values: &mut impl Iterator>, + float_buf: &mut [u8; 300], +) -> Result, Error<'gc>> { + match cont { + EvalContinuation::ToStringResult(args) => { + let val = result.unwrap_or_default(); + let string = val + .into_string(ctx) + .ok_or_else(|| FormatError::BadValueType(spec.spec, "string", val.type_name()))?; + + let len = string.len() as usize; + let truncated_len = args.precision.unwrap_or(len).min(len); + + let pad = args.pad_num_before(w, truncated_len, 0, b"")?; + w.write_all(&string[..truncated_len])?; + pad.finish_pad(w)?; + Ok(EvalPoll::Done) + } + EvalContinuation::UnicodeToStringResult(args) => { + let val = 
result.unwrap_or_default(); + let string = val + .into_string(ctx) + .ok_or_else(|| FormatError::BadValueType(spec.spec, "string", val.type_name()))?; + + let string = core::str::from_utf8(string.as_bytes()) + .map_err(|_| FormatError::NonUnicodeString(spec.spec))?; + + let precision = args.precision.unwrap_or(string.len()); + + // Find the end byte of the string when truncated to `precision` chars. + let (end_byte, end_char) = string + .char_indices() + .map(|(i, c)| i + c.len_utf8()) + .zip(1..precision + 1) + .last() + .unwrap_or((0, 0)); + + let pad = args.pad_num_before(w, end_char, 0, b"")?; + w.write_all(string[..end_byte].as_bytes())?; + pad.finish_pad(w)?; + Ok(EvalPoll::Done) + } + EvalContinuation::Init => evaluate_specifier(ctx, w, spec, values, float_buf), + } +} + +fn evaluate_specifier<'gc, W: Write>( + ctx: Context<'gc>, + w: &mut W, + spec: FormatSpecifier, + values: &mut impl Iterator>, + float_buf: &mut [u8; 300], +) -> Result, Error<'gc>> { + match spec.spec { + b'%' => { + // escaped % + spec.check_flags(Flags::NONE)?; + w.write_all(b"%")?; + } + b'c' => { + // char + spec.check_flags(Flags::LEFT_ALIGN | Flags::WIDTH)?; + let args = spec.common_args(values)?; + + let int = spec.next_int(values)?; + let byte = (int % 256) as u8; + + let pad = args.pad_num_before(w, 1, 0, b"")?; + w.write_all(&[byte])?; + pad.finish_pad(w)?; + } + b'C' => { + // wide char + spec.check_flags(Flags::LEFT_ALIGN | Flags::WIDTH)?; + let args = spec.common_args(values)?; + + let int = spec.next_int(values)?; + let c: char = (u32::try_from(int).ok().and_then(char::from_u32)) + .ok_or(FormatError::ValueOutOfRange(spec.spec))?; + + let pad = args.pad_num_before(w, 1, 0, b"")?; + write!(w, "{}", c)?; + pad.finish_pad(w)?; + } + b's' => { + // string + spec.check_flags(Flags::LEFT_ALIGN | Flags::WIDTH | Flags::PRECISION)?; + let args = spec.common_args(values)?; + let val = spec.next_value(values)?; + // Continue in `evaluate_continuation` + return 
Ok(EvalPoll::from_metaresult( + meta_ops::tostring(ctx, val)?, + EvalContinuation::ToStringResult(args), + )); + } + b'S' => { + // utf-8 string + spec.check_flags(Flags::LEFT_ALIGN | Flags::WIDTH | Flags::PRECISION)?; + let args = spec.common_args(values)?; + let val = spec.next_value(values)?; + // Continue in `evaluate_continuation` + return Ok(EvalPoll::from_metaresult( + meta_ops::tostring(ctx, val)?, + EvalContinuation::UnicodeToStringResult(args), + )); + } + b'd' | b'i' => { + // signed int + spec.check_flags(Flags::SINT_ALLOWED)?; + let args = spec.common_args(values)?; + + let int = spec.next_int(values)?; + let value = int.unsigned_abs(); + let len = integer_length(value); + let sign = args.sign_char(int < 0); + + let zeroed_width = args.integer_zeroed_width(sign); + let pad = args.pad_num_before(w, len, zeroed_width, sign)?; + write!(w, "{}", value)?; + pad.finish_pad(w)?; + } + s @ (b'u' | b'o' | b'x' | b'X' | b'b' | b'B') => { + // unsigned int + if s == b'u' { + spec.check_flags(Flags::UINT_ALLOWED)?; + } else { + spec.check_flags(Flags::UINT_ALLOWED | Flags::ALTERNATE)?; + } + let args = spec.common_args(values)?; + let int = spec.next_int(values)? 
as u64; + + let len = match s { + b'x' | b'X' => integer_length_hex(int), + b'b' | b'B' => integer_length_binary(int), + b'o' => integer_length_octal(int), + b'u' => integer_length(int), + _ => unreachable!(), + }; + let prefix: &[u8] = match (args.alternate, s) { + (true, b'x') => b"0x", + (true, b'X') => b"0X", + (true, b'b') => b"0b", + (true, b'B') => b"0B", + (true, b'o') => b"0", + (true, b'u') => b"", + (_, _) => b"", + }; + + let zeroed_width = args.integer_zeroed_width(prefix); + let pad = args.pad_num_before(w, len, zeroed_width, prefix)?; + match s { + b'x' => write!(w, "{:x}", int)?, + b'X' => write!(w, "{:X}", int)?, + b'b' => write!(w, "{:b}", int)?, + b'B' => write!(w, "{:b}", int)?, + b'o' => write!(w, "{:o}", int)?, + b'u' => write!(w, "{}", int)?, + _ => unreachable!(), + } + pad.finish_pad(w)?; + } + c @ (b'g' | b'G' | b'e' | b'E' | b'f' | b'F' | b'a' | b'A') => { + // floating point number + spec.check_flags(Flags::ALL)?; + let args = spec.common_args(values)?; + + let mode = match c { + b'g' | b'G' => FloatMode::Compact, + b'e' | b'E' => FloatMode::Exponent, + b'f' | b'F' => FloatMode::Normal, + b'a' | b'A' => FloatMode::Hex, + _ => unreachable!(), + }; + let float = spec.next_float(values)?; + float::write_float(w, float, mode, args, float_buf)?; + } + b'p' => { + // object pointer + spec.check_flags(Flags::LEFT_ALIGN | Flags::WIDTH)?; + let args = spec.common_args(values)?; + + // TODO: Intentionally leaking addresses is a bad idea + // This defeats ASLR and simplifies potential exploits. 
+ // (though addrs are currently already exposed through tostring on fns/tables) + let val = spec.next_value(values)?; + let ptr = match val { + Value::Nil | Value::Boolean(_) | Value::Integer(_) | Value::Number(_) => 0, + Value::String(str) => str.as_ptr() as usize, + Value::Table(t) => Gc::as_ptr(t.into_inner()) as usize, + Value::Function(Function::Closure(c)) => Gc::as_ptr(c.into_inner()) as usize, + Value::Function(Function::Callback(c)) => Gc::as_ptr(c.into_inner()) as usize, + Value::Thread(t) => Gc::as_ptr(t.into_inner()) as usize, + Value::UserData(u) => Gc::as_ptr(u.into_inner()) as usize, + }; + + if ptr != 0 { + let len = integer_length_hex(ptr as u64); + let pad = args.pad_num_before(w, len, 0, b"0x")?; + write!(w, "{:x}", ptr)?; + pad.finish_pad(w)?; + } else { + let null_str = "(null)"; + let pad = args.pad_num_before(w, null_str.len(), 0, b"")?; + write!(w, "{}", null_str)?; + pad.finish_pad(w)?; + } + } + b'q' => { + // Lua escape + spec.check_flags(Flags::NONE)?; + let val = spec.next_value(values)?; + write_escaped_value(w, val, spec)?; + } + c => return Err(FormatError::BadSpec(c).into()), + } + Ok(EvalPoll::Done) +} + +fn write_escaped_value<'gc, W: Write>( + w: &mut W, + val: Value<'gc>, + spec: FormatSpecifier, +) -> Result<(), Error<'gc>> { + match val { + Value::Nil => { + write!(w, "nil")?; + Ok(()) + } + Value::Boolean(b) => { + write!(w, "{}", b)?; + Ok(()) + } + Value::Integer(i) => { + if i == i64::MIN { + // MIN is not representable as positive, would be lexed as float + // PRLua outputs 0x8000000000000000 here, which is interpreted as + // a signed integer, but piccolo doesn't; instead we output a simple + // expression to avoid lexer issues. + write!(w, "({}-1)", i + 1)?; + } else { + write!(w, "{}", i)?; + } + Ok(()) + } + Value::Number(n) => { + // These encodings match PRLua's %q, but by the spec they just + // need to be able to round-trip as Lua expressions. 
+ if n.is_finite() { + float::write_hex_float(w, n, FormatArgs::default())?; + } else if n.is_nan() { + write!(w, "(0/0)")?; + } else { + // +/- infinity + let sign = if n.is_sign_negative() { "-" } else { "" }; + write!(w, "{}1e9999", sign)?; + } + Ok(()) + } + Value::String(str) => { + write!(w, "\"")?; + for seg in str.as_bytes().utf8_chunks() { + let mut valid = seg.valid().chars().peekable(); + while let Some(c) = valid.next() { + match c { + c @ ('\\' | '"') => write!(w, "\\{}", c)?, + '\n' => write!(w, "\\\n")?, + c if c.is_ascii_control() => { + if matches!(valid.peek(), Some('0'..='9')) { + write!(w, "\\{:03}", c as u32)? + } else { + write!(w, "\\{}", c as u32)? + } + } + c => write!(w, "{}", c)?, + } + } + for c in seg.invalid() { + write!(w, "\\{}", *c)?; + } + } + write!(w, "\"")?; + Ok(()) + } + _ => Err(FormatError::BadValueType(spec.spec, "constant", val.type_name()).into()), + } +} diff --git a/src/stdlib/string/format/float.rs b/src/stdlib/string/format/float.rs new file mode 100644 index 00000000..ec9d142d --- /dev/null +++ b/src/stdlib/string/format/float.rs @@ -0,0 +1,303 @@ +use core::cmp::Ordering; +use std::io::Write; + +use super::{integer_length, FormatArgs, FormatError}; +use crate::Error; + +pub enum FloatMode { + Normal, + Exponent, + Compact, + Hex, +} + +pub fn write_float<'gc, W: Write>( + w: &mut W, + float: f64, + mode: FloatMode, + args: FormatArgs, + float_buf: &mut [u8], +) -> Result<(), Error<'gc>> { + let sign = args.sign_char(float.is_sign_negative()); + + let preserve_decimal = args.alternate; + let width = args.width; + let precision = args.precision.unwrap_or(6); + + if !float.is_finite() { + return write_nonfinite_float(w, float, args, sign).map_err(Into::into); + } + + if matches!(mode, FloatMode::Hex) { + return write_hex_float(w, float, args).map_err(Into::into); + } + + if matches!(mode, FloatMode::Compact | FloatMode::Exponent) { + let p = if matches!(mode, FloatMode::Compact) { + precision.saturating_sub(1) + } else 
{ + precision + }; + let formatted = format_into_buffer(&mut *float_buf, format_args!("{:+.p$e}", float))?; + + let idx = formatted.rfind('e').ok_or(FormatError::BadFloat)?; + let exp = formatted[idx + 1..] + .parse::() + .map_err(|_| FormatError::BadFloat)?; + + // Note: Rust does not include a leading '+' in the exponent notation, but Lua does, + // so we must calculate the length manually. + let exp_len = 1 + integer_length(exp.unsigned_abs() as u64); + + // Implementation of %g, following the description of the algorithm + // in Python's documentation: + // https://docs.python.org/3/library/string.html#format-specification-mini-language + if matches!(mode, FloatMode::Compact) && exp >= -4 && (exp as i64) < (precision as i64) { + let p = (precision as i64 - 1 - exp as i64) as usize; + + let formatted_compact; + if preserve_decimal { + // Add a decimal at the end, in case Rust doesn't generate one; then strip it out + let s = format_into_buffer(&mut *float_buf, format_args!("{:+.p$}.", float))?; + if s[1..s.len() - 1].contains('.') { + formatted_compact = &s[1..s.len() - 1]; + } else { + formatted_compact = &s[1..]; + } + } else { + let s = format_into_buffer(&mut *float_buf, format_args!("{:+.p$}", float))?; + formatted_compact = strip_nonsignificant_zeroes(&s[1..]); + } + + let len = formatted_compact.len(); + let zero_width = if args.zero_pad { width } else { 0 }; + + let pad = args.pad_num_before(w, len, zero_width, sign)?; + write!(w, "{}", formatted_compact)?; + pad.finish_pad(w)?; + } else { + // exponent mode: + // [ ][-][000][a.bbb][e][+EE] + + let mut mantissa = &formatted[1..idx]; + if matches!(mode, FloatMode::Compact) && !preserve_decimal { + mantissa = strip_nonsignificant_zeroes(mantissa); + } + let e = if args.upper { 'E' } else { 'e' }; + + let exp_len = exp_len.max(3); + let len = mantissa.len() + 1 + exp_len; + let zero_width = if args.zero_pad { width } else { 0 }; + + if preserve_decimal && !formatted.contains('.') { + let pad = 
args.pad_num_before(w, len + 1, zero_width, sign)?; + write!(w, "{mantissa}.{e}{exp:+03}")?; + pad.finish_pad(w)?; + } else { + let pad = args.pad_num_before(w, len, zero_width, sign)?; + write!(w, "{mantissa}{e}{exp:+03}")?; + pad.finish_pad(w)?; + } + } + } else { + // normal float + // This can be larger than any reasonable buffer, so we have + // to forward everything to std (or find a float serialization + // library with custom formatting support.) + + // TODO: cannot support the '#' preserving decimal mode + // string.format("'%#.0f'", 1) should result in "1." + match (args.left_align, args.zero_pad, sign) { + (false, false, b"" | b"-") => write!(w, "{float:width$.precision$}")?, + (false, true, b"" | b"-") => write!(w, "{float:>0width$.precision$}")?, + (false, false, b"+") => write!(w, "{float:+width$.precision$}")?, + (false, true, b"+") => write!(w, "{float:>+0width$.precision$}")?, + (false, false, b" ") => write!(w, " {float:width$.precision$}")?, + (false, true, b" ") => write!(w, " {float:>0width$.precision$}")?, + (true, _, b"" | b"-") => write!(w, "{float: write!(w, "{float:<+width$.precision$}")?, + (true, _, b" ") => write!(w, " {float: unreachable!(), + } + } + Ok(()) +} + +fn write_nonfinite_float( + w: &mut W, + float: f64, + args: FormatArgs, + sign: &[u8], +) -> Result<(), std::io::Error> { + let s = match (float.is_infinite(), args.upper) { + (true, false) => "inf", + (true, true) => "INF", + (false, false) => "nan", + (false, true) => "NAN", + }; + let pad = args.pad_num_before(w, s.len(), 0, sign)?; + write!(w, "{s}")?; + pad.finish_pad(w)?; + Ok(()) +} + +const F64_EXPONENT_BITS: u32 = 11; +const F64_MANTISSA_BITS: u32 = 52; +const F64_EXP_OFFSET: i16 = -(1 << (F64_EXPONENT_BITS - 1)) + 1; + +#[inline] +const fn bitselect(n: u64, off: u32, count: u32) -> u64 { + (n >> off) & ((1 << count) - 1) +} + +fn round_mantissa(mantissa: u64, exp_bits: u16, precision: usize) -> (u64, u64) { + let leading_bit = (exp_bits != 0) as u64; + let mantissa 
= mantissa | (leading_bit << F64_MANTISSA_BITS);
+    let used_mantissa_bits = (precision as u32 * 4).min(F64_MANTISSA_BITS);
+
+    let remainder_bits = F64_MANTISSA_BITS - used_mantissa_bits;
+    let quotient = mantissa >> remainder_bits;
+    let remainder = bitselect(mantissa, 0, remainder_bits);
+    let rounded_quotient = match remainder.cmp(&(1 << remainder_bits.saturating_sub(1))) {
+        Ordering::Less => quotient,
+        Ordering::Equal => (quotient + 1) & !1, // Round to even
+        Ordering::Greater => quotient + 1,
+    };
+
+    let head = rounded_quotient >> used_mantissa_bits;
+    let rounded_div_mantissa = bitselect(rounded_quotient, 0, used_mantissa_bits);
+    (head, rounded_div_mantissa)
+}
+
+pub fn write_hex_float<W: std::io::Write>(
+    w: &mut W,
+    float: f64,
+    args: FormatArgs,
+) -> Result<(), std::io::Error> {
+    let sign = args.sign_char(float.is_sign_negative());
+    let preserve_decimal = args.alternate;
+
+    if !float.is_finite() {
+        return write_nonfinite_float(w, float, args, sign);
+    }
+
+    let width = args.width;
+    let mut precision = args
+        .precision
+        .unwrap_or(F64_MANTISSA_BITS.div_ceil(4) as usize);
+
+    let bits = f64::to_bits(float);
+    let exp_bits = bitselect(bits, F64_MANTISSA_BITS, F64_EXPONENT_BITS);
+    // clamp exponent to -1022 for subnormals
+    let mut exp = (exp_bits as i16 + F64_EXP_OFFSET).max(-1022);
+    let mantissa = bitselect(bits, 0, F64_MANTISSA_BITS);
+
+    if float == 0.0 {
+        exp = 0;
+    }
+
+    let (head, mut mantissa) = round_mantissa(mantissa, exp_bits as u16, precision);
+
+    let prefix: &[u8] = match (sign, args.upper) {
+        (b"", false) => b"0x",
+        (b"-", false) => b"-0x",
+        (b"+", false) => b"+0x",
+        (b" ", false) => b" 0x",
+        (b"", true) => b"0X",
+        (b"-", true) => b"-0X",
+        (b"+", true) => b"+0X",
+        (b" ", true) => b" 0X",
+        _ => unreachable!(),
+    };
+    let zero_width = if args.zero_pad {
+        width.saturating_sub(prefix.len())
+    } else {
+        0
+    };
+
+    if args.precision.is_none() {
+        let trailing_zero_digits = mantissa.trailing_zeros().min(F64_MANTISSA_BITS) / 4;
+        mantissa = mantissa >> (trailing_zero_digits * 4);
+        precision = precision.saturating_sub(trailing_zero_digits as usize);
+    }
+
+    if precision != 0 {
+        // Only 13 hex digits exist; any further requested precision is zero-filled on the right.
+        let m_width = precision.min(F64_MANTISSA_BITS.div_ceil(4) as usize);
+        let len = 2 + precision + 1 + 1 + integer_length(exp.unsigned_abs() as u64);
+        let pad = args.pad_num_before(w, len, zero_width, prefix)?;
+        if args.upper {
+            write!(w, "{head}.{mantissa:0m_width$X}{:0<1$}P{exp:+}", "", precision - m_width)?;
+        } else {
+            write!(w, "{head}.{mantissa:0m_width$x}{:0<1$}p{exp:+}", "", precision - m_width)?;
+        }
+        pad.finish_pad(w)?;
+    } else {
+        let len = 3 + preserve_decimal as usize + integer_length(exp.unsigned_abs() as u64);
+
+        let p = if args.upper { 'P' } else { 'p' };
+        let pad = args.pad_num_before(w, len, zero_width, prefix)?;
+        if preserve_decimal {
+            write!(w, "{head}.{p}{exp:+}")?;
+        } else {
+            write!(w, "{head}{p}{exp:+}")?;
+        }
+        pad.finish_pad(w)?;
+    }
+    Ok(())
+}
+
+fn format_into_buffer<'a>(
+    buf: &'a mut [u8],
+    args: core::fmt::Arguments<'_>,
+) -> Result<&'a str, core::fmt::Error> {
+    use core::fmt::Write;
+
+    struct BufferWriter<'a> {
+        buffer: &'a mut [u8],
+        offset: usize,
+    }
+
+    impl<'a> BufferWriter<'a> {
+        fn new(buffer: &'a mut [u8]) -> Self {
+            Self { buffer, offset: 0 }
+        }
+        fn into_str(self) -> &'a str {
+            let slice = &self.buffer[..self.offset];
+            // Safety: `buffer` can only be filled by write_str,
+            // which requires valid utf8 strings.
+            unsafe { core::str::from_utf8_unchecked(slice) }
+        }
+    }
+
+    impl Write for BufferWriter<'_> {
+        fn write_str(&mut self, s: &str) -> core::fmt::Result {
+            let bytes = s.as_bytes();
+            let dest = self.buffer[self.offset..]
+                .get_mut(..bytes.len())
+                .ok_or(core::fmt::Error)?;
+            dest.copy_from_slice(bytes);
+            self.offset += bytes.len();
+            Ok(())
+        }
+    }
+
+    let mut writer = BufferWriter::new(buf);
+    write!(writer, "{}", args)?;
+    Ok(writer.into_str())
+}
+
+/// Remove trailing zeroes from a formatted number without changing its value.
+fn strip_nonsignificant_zeroes(str: &str) -> &str {
+    if let Some(last_nonzero) = str.rfind(|p| p != '0') {
+        if let Some(decimal) = str[..=last_nonzero].rfind('.') {
+            // If the number ends with a trailing decimal point, remove it.
+            if decimal == last_nonzero {
+                return &str[..last_nonzero];
+            } else {
+                return &str[..=last_nonzero];
+            }
+        }
+    }
+    str
+}
diff --git a/src/stdlib/string/format/parse.rs b/src/stdlib/string/format/parse.rs
new file mode 100644
index 00000000..de845a33
--- /dev/null
+++ b/src/stdlib/string/format/parse.rs
@@ -0,0 +1,101 @@
+use super::{Flags, FormatError, FormatSpecifier, OptionalArg, ARG_MAX, FMT_SPEC};
+
+/// Minimal byte cursor over the format string with one byte of lookahead.
+struct PeekableIter<'a> {
+    base: &'a [u8],
+    cur: &'a [u8],
+}
+
+impl<'a> PeekableIter<'a> {
+    fn new(s: &'a [u8]) -> Self {
+        Self { base: s, cur: s }
+    }
+    fn peek(&mut self) -> Option<u8> {
+        self.cur.first().copied()
+    }
+    fn advance(&mut self) {
+        self.cur = &self.cur[1..];
+    }
+    fn cur_index(&self) -> usize {
+        self.base.len() - self.cur.len()
+    }
+}
+
+/// Parse one conversion specification from `str`; `next` is the index of the
+/// byte that introduces it (presumably the `%`), which is skipped unread.
+/// Returns the specifier and the index of the first byte after its `spec` byte.
+pub fn parse_specifier(str: &[u8], next: usize) -> Result<(FormatSpecifier, usize), FormatError> {
+    let mut iter = PeekableIter::new(&str[next + 1..]);
+
+    let mut flags = Flags::NONE;
+    #[rustfmt::skip]
+    loop {
+        match iter.peek() {
+            Some(b'#') => { iter.advance(); flags |= Flags::ALTERNATE; },
+            Some(b'-') => { iter.advance(); flags |= Flags::LEFT_ALIGN; },
+            Some(b'+') => { iter.advance(); flags |= Flags::SIGN_FORCE; },
+            Some(b' ') => { iter.advance(); flags |= Flags::SIGN_SPACE; },
+            Some(b'0') => { iter.advance(); flags |= Flags::ZERO_PAD; },
+            _ => break,
+        }
+    };
+
+    let width = try_parse_optional_arg(&mut iter).map_err(|_| FormatError::BadWidth)?;
+    if !matches!(width, OptionalArg::None) {
+        flags |= Flags::WIDTH;
+    }
+
+    let precision = if let Some(b'.') = iter.peek() {
+        iter.advance();
+        flags |= Flags::PRECISION;
+        let arg = try_parse_optional_arg(&mut iter).map_err(|_| FormatError::BadPrecision)?;
+        match arg {
+            OptionalArg::None => OptionalArg::Specified(0),
+            arg => arg,
+        }
+    } else {
+        OptionalArg::None
+    };
+
+    let spec = iter.peek().ok_or(FormatError::BadSpec(FMT_SPEC))?;
+    iter.advance();
+    let spec_end = next + 1 + iter.cur_index();
+
+    let specifier = FormatSpecifier {
+        spec,
+        flags,
+        width,
+        precision,
+    };
+    Ok((specifier, spec_end))
+}
+
+/// Parse an optional width/precision: `*` (yields `OptionalArg::Arg`), a decimal
+/// literal, or nothing. Fails if the literal does not parse or exceeds `ARG_MAX`.
+fn try_parse_optional_arg(iter: &mut PeekableIter<'_>) -> Result<OptionalArg, ()> {
+    match iter.peek() {
+        Some(b'*') => {
+            iter.advance();
+            Ok(OptionalArg::Arg)
+        }
+        Some(b'0'..=b'9') => {
+            let rest = &iter.cur[1..];
+            let len = 1 + rest
+                .iter()
+                .position(|c| !c.is_ascii_digit())
+                .unwrap_or(rest.len());
+
+            // We just verified that the slice is only composed of '0'..'9'
+            let slice = core::str::from_utf8(&iter.cur[..len]).map_err(|_| ())?;
+
+            // NOTE(review): the turbofish type was lost in extraction; `usize` is assumed.
+            let num = slice.parse::<usize>().map_err(|_| ())?;
+            if num > ARG_MAX {
+                return Err(());
+            }
+            iter.cur = &iter.cur[len..];
+            Ok(OptionalArg::Specified(num))
+        }
+        _ => Ok(OptionalArg::None),
+    }
+}
diff --git a/tests/scripts-wishlist/num-formats.lua b/tests/scripts-wishlist/num-formats.lua
new file mode 100644
index 00000000..ac73de55
--- /dev/null
+++ b/tests/scripts-wishlist/num-formats.lua
@@ -0,0 +1,28 @@
+
+local function dprint(...)
+    print(string.format("%q", ...))
+end
+
+local function assert_eq(val, exp)
+    -- dprint(val)
+    if val ~= exp then
+        error(string.format("assertion failed; expected %q but found %q", exp, val))
+    end
+end
+
+
+do
+    -- A subnormal float with 1 more bit of precision than is representable
+    -- TODO: The lexer currently incorrectly parses this as 0
+    local mismatch, alt = 0x1.fffffffffffffp-1023, 2.2250738585072013e-308
+    assert_eq(string.format("%a %g", mismatch, mismatch), "0x1p-1022 2.22507e-308")
+    assert_eq(string.format("%a %g", alt, alt), "0x1p-1022 2.22507e-308")
+end
+
+do
+    -- Fixing this requires rewriting the %f case of format_float to
+    -- either not use Rust's float formatter, or to have it write into
+    -- a (large!) intermediate buffer, and then edit parts of the
+    -- generated float.
+    assert_eq(string.format("'%#.0f'", 1), "'1.'")
+end
diff --git a/tests/scripts/format.lua b/tests/scripts/format.lua
new file mode 100644
index 00000000..a241817a
--- /dev/null
+++ b/tests/scripts/format.lua
@@ -0,0 +1,313 @@
+
+local function dprint(...)
+    print(string.format("%q", ...))
+end
+
+local function assert_eq(val, exp)
+    -- dprint(val)
+    if val ~= exp then
+        error(string.format("assertion failed; expected %q but found %q", exp, val))
+    end
+end
+
+
+do
+    assert_eq(string.format("abc%sdef", "iii"), "abciiidef")
+
+    assert_eq(string.format("%s__", 123), "123__")
+    assert_eq(string.format("__%s", 321), "__321")
+
+    assert_eq(string.format("%d %d %d", 1, 2, 3), "1 2 3")
+end
+
+
+do -- string width, truncating
+    assert_eq(string.format("%s", "example"), "example")
+    assert_eq(string.format("%5s", "a"), "    a")
+    assert_eq(string.format("%-5s", "a"), "a    ")
+    assert_eq(string.format("%-s", "example"), "example")
+
+    assert_eq(string.format("%.3s", "example"), "exa")
+    assert_eq(string.format("%.s", "example"), "")
+    assert_eq(string.format("%5.2s", "example"), "   ex")
+    assert_eq(string.format("%-5.2s", "example"), "ex   ")
+    assert_eq(string.format("%2.s", "example"), "  ")
+
+    assert_eq(string.format("%.10s", "example"), "example")
+    assert_eq(string.format("%10.10s", "example"), "   example")
+    assert_eq(string.format("%10.3s", "example"), "       exa")
+
+    assert_eq(string.format("%4s", ""), "    ")
+    assert_eq(string.format("%4.2s", ""), "    ")
+    assert_eq(string.format("%4.10s", ""), "    ")
+end
+
+do -- Integer formatting
+    assert_eq(string.format("'%8d'", -3), "'      -3'")
+    assert_eq(string.format("'%8.4d'", 3), "'    0003'")
+    assert_eq(string.format("'%8.4d'", -3), "'   -0003'")
+    assert_eq(string.format("'%+8.4d'", 3), "'   +0003'")
+    assert_eq(string.format("'% 8.4d'", 3), "'    0003'")
+    assert_eq(string.format("'%08.4d'", 3), "'    0003'")
+    assert_eq(string.format("'%+08.4d'", 3), "'   +0003'")
+    assert_eq(string.format("'%-08.4d'", 3), "'0003    '")
+    assert_eq(string.format("'%-+08.4d'", 3), "'+0003   '")
+    assert_eq(string.format("'%+08.4d'", 123456789), "'+123456789'")
+
+    assert_eq(string.format("'% 1.4d'", 1234), "' 1234'")
+    assert_eq(string.format("'% 5.4d'", 1234), "' 1234'")
+
+    -- '+' takes precedence over ' '
+    assert_eq(string.format("'% +d'", 1), "'+1'")
+    assert_eq(string.format("'%+ d'", 1), "'+1'")
+    assert_eq(string.format("'%+ d'", -1), "'-1'")
+
+    assert_eq(string.format("'%+08d'", 3), "'+0000003'")
+    assert_eq(string.format("'% 08d'", 3), "' 0000003'")
+    assert_eq(string.format("'%08d'", 3), "'00000003'")
+    assert_eq(string.format("'%+08d'", -3), "'-0000003'")
+    assert_eq(string.format("'% 08d'", -3), "'-0000003'")
+    assert_eq(string.format("'%08d'", -3), "'-0000003'")
+
+    assert_eq(string.format("%08d", 1), "00000001")
+    assert_eq(string.format("%08d", -1), "-0000001")
+    assert_eq(string.format("%.8d", 1), "00000001")
+    assert_eq(string.format("%.8d", -1), "-00000001")
+
+    assert_eq(string.format("%+08d", 1), "+0000001")
+    assert_eq(string.format("%+08d", -1), "-0000001")
+    assert_eq(string.format("%+.8d", 1), "+00000001")
+    assert_eq(string.format("%+.8d", -1), "-00000001")
+
+    assert_eq(string.format("%#16.8x", 235678), "      0x0003989e")
+
+    -- Unsigned integers
+
+    assert_eq(string.format("'%8.4u'", 3), "'    0003'")
+    assert_eq(string.format("'%8.4u'", -3), "'18446744073709551613'")
+    assert_eq(string.format("'%08.4u'", 3), "'    0003'")
+    assert_eq(string.format("'%-08.4u'", 3), "'0003    '")
+
+    assert_eq(string.format("'%1.4u'", 1234), "'1234'")
+    assert_eq(string.format("'%5.4u'", 1234), "' 1234'")
+
+    assert_eq(string.format("'%08u'", 3), "'00000003'")
+    assert_eq(string.format("'%08u'", -3), "'18446744073709551613'")
+    assert_eq(string.format("'%8u'", 3), "'       3'")
+
+    assert_eq(string.format("'%-8u'", 3), "'3       '")
+    assert_eq(string.format("'%-08u'", 3), "'3       '")
+end
+
+do -- Floating point formatting
+    assert_eq(string.format("%g", 0), "0")
+    assert_eq(string.format("%g", -0.0), "-0")
+    assert_eq(string.format("%+g", 0), "+0")
+    assert_eq(string.format("% g", 0), " 0")
+
+    assert_eq(string.format("%g", 1), "1")
+    assert_eq(string.format("%g", -1), "-1")
+    assert_eq(string.format("%+g", 1), "+1")
+    assert_eq(string.format("% g", 1), " 1")
+
+    assert_eq(string.format("%g", 1.500001), "1.5")
+    assert_eq(string.format("%g", 1.50001), "1.50001")
+    assert_eq(string.format("%.1g", 1.5), "2")
+    assert_eq(string.format("%g", 1000), "1000")
+    assert_eq(string.format("%g", 100000), "100000")
+    assert_eq(string.format("%g", 1000000), "1e+06")
+
+    assert_eq(string.format("%8g", 1), "       1")
+    assert_eq(string.format("%8g", 1.500001), "     1.5")
+    assert_eq(string.format("%8g", 1.50001), " 1.50001")
+    assert_eq(string.format("%8.1g", 1.5), "       2")
+    assert_eq(string.format("%8g", 1000), "    1000")
+    assert_eq(string.format("%8g", 100000), "  100000")
+    assert_eq(string.format("%8g", 1000000), "   1e+06")
+    assert_eq(string.format("%8G", 1000000), "   1E+06")
+    assert_eq(string.format("%8e", 1000000), "1.000000e+06")
+    assert_eq(string.format("%8E", 1000000), "1.000000E+06")
+
+    assert_eq(string.format("%g", 0/0), "-nan")
+    assert_eq(string.format("%g", math.abs(0/0)), "nan")
+    assert_eq(string.format("%g", 1/0), "inf")
+    assert_eq(string.format("%g", -1/0), "-inf")
+    assert_eq(string.format("%G", 0/0), "-NAN")
+    assert_eq(string.format("%G", math.abs(0/0)), "NAN")
+    assert_eq(string.format("%G", 1/0), "INF")
+    assert_eq(string.format("%G", -1/0), "-INF")
+
+    assert_eq(string.format("%05g", 0/0), " -nan")
+    assert_eq(string.format("%05g", math.abs(0/0)), "  nan")
+    assert_eq(string.format("%05g", 1/0), "  inf")
+    assert_eq(string.format("%05g", -1/0), " -inf")
+
+    assert_eq(string.format("%05.8g", 0/0), " -nan")
+    assert_eq(string.format("%05.8g", math.abs(0/0)), "  nan")
+    assert_eq(string.format("%05.8g", 1/0), "  inf")
+    assert_eq(string.format("%05.8g", -1/0), " -inf")
+
+    assert_eq(#string.format("%099.99f", 1.7976931348623158e308), 409)
+
+    assert_eq(string.format("%0.0f", 15.1234), "15")
+    assert_eq(string.format("%4.0f", 15.1234), "  15")
+    assert_eq(string.format("'%-8.3f'", 15.1234), "'15.123  '")
+    assert_eq(string.format("'%-+8.3f'", 15.1234), "'+15.123 '")
+    assert_eq(string.format("'%0+8.3f'", 15.1234), "'+015.123'")
+
+    -- assert_eq(string.format("'%#.0f'", 1), "'1.'") -- Can't easily fix this, relying on rust's std
+    assert_eq(string.format("'%.0f'", 1), "'1'")
+    assert_eq(string.format("'%g'", 1), "'1'")
+    assert_eq(string.format("'%#g'", 1), "'1.00000'")
+    assert_eq(string.format("'%#g'", 100000), "'100000.'")
+
+    assert_eq(string.format("'%.0e'", 1), "'1e+00'")
+    assert_eq(string.format("'%#.0e'", 1), "'1.e+00'")
+
+    -- Note: PRLua does not support %F (uppercase float; NAN/INF instead of nan/inf).
+    assert_eq(string.format("'%f'", 1.0/0.0), "'inf'")
+    assert_eq(string.format("'%e' '%E'", 1.0/0.0, 1.0/0.0), "'inf' 'INF'")
+    assert_eq(string.format("'%g' '%G'", 1.0/0.0, 1.0/0.0), "'inf' 'INF'")
+
+    -- From glibc's manual:
+    -- https://www.gnu.org/software/libc/manual/html_node/Floating_002dPoint-Conversions.html
+    local expected = {
+        "|  0x0.0000p+0|       0.0000|   0.0000e+00|            0|",
+        "|  0x1.0000p-1|       0.5000|   5.0000e-01|          0.5|",
+        "|  0x1.0000p+0|       1.0000|   1.0000e+00|            1|",
+        "| -0x1.0000p+0|      -1.0000|  -1.0000e+00|           -1|",
+        "|  0x1.9000p+6|     100.0000|   1.0000e+02|          100|",
+        "|  0x1.f400p+9|    1000.0000|   1.0000e+03|         1000|",
+        "| 0x1.3880p+13|   10000.0000|   1.0000e+04|        1e+04|",
+        "| 0x1.81c8p+13|   12345.0000|   1.2345e+04|    1.234e+04|",
+        "| 0x1.86a0p+16|  100000.0000|   1.0000e+05|        1e+05|",
+        "| 0x1.e240p+16|  123456.0000|   1.2346e+05|    1.235e+05|",
+    }
+    for i, v in ipairs({ 0, 0.5, 1, -1, 100, 1000, 10000, 12345, 1e5, 123456 }) do
+        assert_eq(string.format("|%13.4a|%13.4f|%13.4e|%13.4g|", v, v, v, v), expected[i])
+    end
+
+    assert_eq(string.format("%13.4a", 0.0/0.0), "         -nan")
+    assert_eq(string.format("%13.4a", 1.0/0.0), "          inf")
+    assert_eq(string.format("%13.4a", math.pi), "  0x1.9220p+1")
+    assert_eq(string.format("%13.4a", math.huge), "          inf")
+
+    assert_eq(string.format("%13.4a", 0x1.00008p+0), "  0x1.0000p+0")
+    assert_eq(string.format("%13.4a", 0x1.000081p+0), "  0x1.0001p+0")
+    assert_eq(string.format("%13.4a", 0x1.00007fp+0), "  0x1.0000p+0")
+
+    -- Round to even
+    assert_eq(string.format("%13.4a", 0x1.00018p+0), "  0x1.0002p+0")
+    assert_eq(string.format("%13.4a", 0x1.000181p+0), "  0x1.0002p+0")
+    assert_eq(string.format("%13.4a", 0x1.00017fp+0), "  0x1.0001p+0")
+
+    assert_eq(string.format("%13a", math.pi), "0x1.921fb54442d18p+1")
+    assert_eq(string.format("%013.1a", math.pi), "0x000001.9p+1")
+    assert_eq(string.format("%013.0a", math.pi), "0x00000002p+1")
+
+    assert_eq(string.format("%#013.0a", 4.0), "0x0000001.p+2")
+    assert_eq(string.format("%#013.0a", math.pi), "0x0000002.p+1")
+
+    assert_eq(string.format("%+#15.8g", 0x1.0p+64), " +1.8446744e+19")
+    assert_eq(string.format("%+#15.8g", 0x1.0p-64), " +5.4210109e-20")
+end
+
+do -- alternate integer formats (hex, octal)
+    assert_eq(string.format("'%8.4x'", 0xAB), "'    00ab'")
+    assert_eq(string.format("'%#8.4x'", 0xAB), "'  0x00ab'")
+    assert_eq(string.format("'%08.4x'", 0xAB), "'    00ab'")
+    assert_eq(string.format("'%#08.4x'", 0xAB), "'  0x00ab'")
+    assert_eq(string.format("'%08x'", 0xAB), "'000000ab'")
+    assert_eq(string.format("'%#08x'", 0xAB), "'0x0000ab'")
+    assert_eq(string.format("'%08.x'", 0xAB), "'      ab'")
+    assert_eq(string.format("'%#08.x'", 0xAB), "'    0xab'")
+end
+
+do -- Floating-point edgecases
+    assert_eq(string.format("%a", 0x1.fffffffffffffp+0), "0x1.fffffffffffffp+0")
+    assert_eq(string.format("%.7a", 0x1.fffffffffffffp+0), "0x2.0000000p+0")
+    assert_eq(string.format("%.7a", 0x1.88fffffffffffp+0), "0x1.8900000p+0")
+
+    local smallest_normal = 0x1p-1022
+    local largest_subnormal = 0x0.fffffffffffffp-1022
+    local largest_subnormal_equiv = 0x1.ffffffffffffep-1023
+    local smallest_subnormal = 0x1p-1074
+
+    assert_eq(string.format("%a %g", smallest_normal, smallest_normal), "0x1p-1022 2.22507e-308")
+    assert_eq(string.format("%a %g", largest_subnormal, largest_subnormal), "0x0.fffffffffffffp-1022 2.22507e-308")
+    assert_eq(string.format("%a %g", smallest_subnormal, smallest_subnormal), "0x0.0000000000001p-1022 4.94066e-324")
+
+    local small = 0x0.fffffffffffffp-1020
+    assert_eq(string.format("%a %g", small, small), "0x1.ffffffffffffep-1021 8.9003e-308")
+
+    -- A subnormal float with 1 more bit of precision than is representable
+    -- TODO: The lexer currently incorrectly parses this as 0
+    local mismatch, alt = 0x1.fffffffffffffp-1023, 2.2250738585072013e-308
+    -- assert_eq(string.format("%a %g", mismatch, mismatch), "0x1p-1022 2.22507e-308")
+    assert_eq(string.format("%a %g", alt, alt), "0x1p-1022 2.22507e-308")
+
+    local n = 0x1.888p-1022
+    assert_eq(string.format("%a %g", n, n), "0x1.888p-1022 3.41149e-308")
+
+    -- test preserve decimal
+    assert_eq(string.format("%#a %g", smallest_normal, smallest_normal), "0x1.p-1022 2.22507e-308")
+
+    -- test fixed precision
+    assert_eq(string.format("%.13a %g", smallest_normal, smallest_normal), "0x1.0000000000000p-1022 2.22507e-308")
+end
+
+-- TODO: implement load for "%q" round-trip testing
+
+do -- "%q" round-tripping for strings
+    local s = ""
+    -- table.concat would be nice here...
+    for i = 1, 127 do
+        s = s .. string.format("%c", i)
+    end
+    s = string.format("%q", s)
+    assert_eq(s, [["\1\2\3\4\5\6\7\8\9\
+\11\12\13\14\15\16\17\18\19\20\21\22\23\24\25\26\27\28\29\30\31 !\"]]..
+[[#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`ab]]..
+[[cdefghijklmnopqrstuvwxyz{|}~\127"]])
+
+    assert_eq(string.format("%q", "asdf\nabc\"def"), "\"asdf\\\nabc\\\"def\"")
+
+    assert_eq(string.format("%q", "\0" .. "\1"), "\"\\0\\1\"")
+    assert_eq(string.format("%q", "\0" .. "1"), "\"\\0001\"")
+end
+
+do
+    assert_eq(string.format("%q", (0.0/0)), "(0/0)")
+    assert_eq(string.format("%q", (1.0/0)), "1e9999")
+    assert_eq(string.format("%q", -(0.0/0)), "(0/0)")
+    assert_eq(string.format("%q", -(1.0/0)), "-1e9999")
+    assert_eq(string.format("%q %q", 1.0, math.pi), "0x1p+0 0x1.921fb54442d18p+1")
+
+    assert_eq(string.format("%q", 15.345678), "0x1.eb0fcb4f1e4b4p+3")
+
+    -- TODO: impl load, make sure values round-trip
+    -- assert_eq(string.format("%q", math.mininteger), "0x8000000000000000")
+    -- assert_eq(string.format("%q", math.mininteger), "(-9223372036854775807-1)")
+end
+
+do -- Metamethod testing
+    local stringable = setmetatable({}, { __tostring = function ()
+        return "abc"
+    end })
+
+    assert_eq(string.format("'%s'", stringable), "'abc'")
+    -- assert_eq(string.format("%q", stringable), "\"abc\"")
+end
+
+-- Piccolo-specific tests
+if piccolo then
+    assert_eq(string.format("%*s", 3, "a"), "  a")
+    assert_eq(string.format("%-*s", 3, "a"), "a  ")
+    assert_eq(string.format("%*s", -3, "a"), "a  ")
+
+    assert_eq(string.format("%.2S", "aa😀bb"), "aa")
+    assert_eq(string.format("%.3S", "aa😀bb"), "aa😀")
+    assert_eq(string.format("%S", "aa😀bb"), "aa😀bb")
+
+    assert_eq(string.format("%C", 0x1F600), "😀")
+end

From c65c28d7b12d626708290ce6aa57d4edcb66cd31 Mon Sep 17 00:00:00 2001
From: Aeledfyr
Date: Tue, 29 Apr 2025 22:47:49 -0500
Subject: [PATCH 2/2] Support subnormal hexfloat literals

---
 src/compiler/string_utils.rs | 54 +++++++++++++++++++++++++++++++++++-
 tests/scripts/format.lua     |  3 +--
 tests/scripts/subnormal.lua  | 46 +++++++++++++++++++++++++++++++++
 3 files changed, 100 insertions(+), 3 deletions(-)
 create mode 100644 tests/scripts/subnormal.lua

diff --git a/src/compiler/string_utils.rs b/src/compiler/string_utils.rs
index b94497cc..00c94138 100644
--- a/src/compiler/string_utils.rs
+++ b/src/compiler/string_utils.rs
@@ -242,8 +242,60 @@ pub fn read_hex_float(s: &[u8]) -> Option<f64> {
     if is_neg {
         base = -base;
     }
+    Some(ldexp(base, exp))
+}
+
+/// Multiply `val` by `2^exp` (like C's `ldexp`), handling subnormal inputs and
+/// results, overflow to infinity, and underflow to a signed zero.
+fn ldexp(mut val: f64, exp: i32) -> f64 {
+    fn iexp2(exp: i32) -> f64 {
+        assert!(exp >= -1022 && exp <= 1023);
+        f64::from_bits(((exp + 1023) as u64) << 52)
+    }
+    fn extract_exp(val: f64) -> u64 {
+        (val.to_bits() >> 52) & 0x7ff
+    }
+    fn extract_mantissa(val: f64) -> u64 {
+        val.to_bits() & ((1 << 52) - 1)
+    }
+
+    let mut orig_exp = extract_exp(val) as i32;
+    let mut mantissa = extract_mantissa(val);
+    if orig_exp == 0 {
+        if mantissa == 0 {
+            return val;
+        }
+        // Input is subnormal: scale it into the normal range, then re-read the
+        // now-normalized mantissa so that it matches the adjusted exponent.
+        val *= iexp2(54);
+        orig_exp = extract_exp(val) as i32 - 54;
+        mantissa = extract_mantissa(val);
+    } else if orig_exp == 2047 {
+        // input is NaN or Inf
+        return val;
+    }
 
-    Some(base * (exp as f64).exp2())
+    let new_exp = orig_exp.saturating_add(exp);
+    if new_exp >= 2047 {
+        f64::copysign(f64::INFINITY, val) // overflow
+    } else if new_exp <= -54 {
+        f64::copysign(0.0, val) // underflow
+    } else {
+        if new_exp > 0 {
+            f64::from_bits(
+                ((val.is_sign_negative() as u64) << 63)
+                    | ((new_exp as u64) << 52)
+                    | mantissa,
+            )
+        } else {
+            // Output is subnormal
+            f64::from_bits(
+                ((val.is_sign_negative() as u64) << 63)
+                    | (((new_exp + 54) as u64) << 52)
+                    | mantissa,
+            ) * iexp2(-54)
+        }
+    }
 }
 
 /// Read an optional '-' or '+' prefix and return whether the value is negated (starts with a '-'
diff --git a/tests/scripts/format.lua b/tests/scripts/format.lua
index a241817a..0d3fdbf1 100644
--- a/tests/scripts/format.lua
+++ b/tests/scripts/format.lua
@@ -241,9 +241,8 @@ do -- Floating-point edgecases
     assert_eq(string.format("%a %g", small, small), "0x1.ffffffffffffep-1021 8.9003e-308")
 
     -- A subnormal float with 1 more bit of precision than is representable
-    -- TODO: The lexer currently incorrectly parses this as 0
     local mismatch, alt = 0x1.fffffffffffffp-1023, 2.2250738585072013e-308
-    -- assert_eq(string.format("%a %g", mismatch, mismatch), "0x1p-1022 2.22507e-308")
+    assert_eq(string.format("%a %g", mismatch, mismatch), "0x1p-1022 2.22507e-308")
     assert_eq(string.format("%a %g", alt, alt), "0x1p-1022 2.22507e-308")
 
     local n = 0x1.888p-1022
diff --git a/tests/scripts/subnormal.lua b/tests/scripts/subnormal.lua
new file mode 100644
index 00000000..56dd8103
--- /dev/null
+++ b/tests/scripts/subnormal.lua
@@ -0,0 +1,46 @@
+
+local function dprint(...)
+    print(string.format("%q", ...))
+end
+
+local function assert_eq(val, exp)
+    -- dprint(val)
+    if val ~= exp then
+        error(string.format("assertion failed; expected %q but found %q", exp, val))
+    end
+end
+
+local a = 0x1.fffffffffffffp-1023
+assert_eq(string.format("%a", a, a), "0x1p-1022")
+local b = 0x1.ffffffffffffep-1023
+assert_eq(string.format("%a", b, b), "0x0.fffffffffffffp-1022")
+print("a")
+local c = 0x1.ffffffffffffdp-1023
+assert_eq(string.format("%a", c, c), "0x0.ffffffffffffep-1022")
+print("b")
+local d = 0x1.ffffffffffffcp-1023
+assert_eq(string.format("%a", d, d), "0x0.ffffffffffffep-1022")
+local d = 0x1.ffffffffffff0p-1026
+assert_eq(string.format("%a", d, d), "0x0.1ffffffffffffp-1022")
+local d = 0x1.ffffffffffff7p-1026
+assert_eq(string.format("%a", d, d), "0x0.1ffffffffffffp-1022")
+local d = 0x1.ffffffffffff8p-1026
+assert_eq(string.format("%a", d, d), "0x0.2p-1022")
+
+
+local a = -0x1.fffffffffffffp-1023
+assert_eq(string.format("%a", a, a), "-0x1p-1022")
+local b = -0x1.ffffffffffffep-1023
+assert_eq(string.format("%a", b, b), "-0x0.fffffffffffffp-1022")
+local c = -0x1.ffffffffffffdp-1023
+assert_eq(string.format("%a", c, c), "-0x0.ffffffffffffep-1022")
+local d = -0x1.ffffffffffffcp-1023
+assert_eq(string.format("%a", d, d), "-0x0.ffffffffffffep-1022")
+
+local mismatch, alt = 0x1.fffffffffffffp-1023, 2.2250738585072013e-308
+assert_eq(string.format("%a %g", mismatch, mismatch), "0x1p-1022 2.22507e-308")
+assert_eq(string.format("%a %g", alt, alt), "0x1p-1022 2.22507e-308")
+
+local a, b = 0x1.ffffffffffffep-1023, 2.22507385850720088902e-308
+assert_eq(string.format("%a %g", a, a), "0x0.fffffffffffffp-1022 2.22507e-308")
+assert_eq(string.format("%a %g", b, b), "0x0.fffffffffffffp-1022 2.22507e-308")