diff --git a/mojoproject.toml b/mojoproject.toml index 075e92f..f0147b9 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -33,9 +33,13 @@ t = "clear && magic run test" bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .." b = "clear && magic run bench" +# individual bench files +b_mul = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .." +b_div = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .." + # before commit final = "magic run test && magic run bench" f = "clear && magic run final" [dependencies] -max = ">=25.1,<25.3" +max = ">=25.1,<25.3" \ No newline at end of file diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index b85bac0..80e3331 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -876,7 +876,7 @@ struct Decimal( fn __rsub__(self, other: Int) raises -> Self: return decimojo.subtract(Decimal(other), self) - fn __mul__(self, other: Decimal) -> Self: + fn __mul__(self, other: Decimal) raises -> Self: """ Multiplies two Decimal values and returns a new Decimal containing the product. """ @@ -886,7 +886,7 @@ struct Decimal( fn __mul__(self, other: Float64) raises -> Self: return decimojo.multiply(self, Decimal(other)) - fn __mul__(self, other: Int) -> Self: + fn __mul__(self, other: Int) raises -> Self: return decimojo.multiply(self, Decimal(other)) fn __truediv__(self, other: Decimal) raises -> Self: @@ -1133,6 +1133,18 @@ struct Decimal( """Returns True if this Decimal is NaN (Not a Number).""" return (self.flags & Self.NAN_MASK) != 0 + fn is_uint32able(self) -> Bool: + """ + Returns True if the coefficient can be represented as a UInt32 value. + """ + return self.high == 0 and self.mid == 0 + + fn is_uint64able(self) -> Bool: + """ + Returns True if the coefficient can be represented as a UInt64 value. + """ + return self.high == 0 + fn scale(self) -> Int: """Returns the scale (number of decimal places) of this Decimal.""" return Int((self.flags & Self.SCALE_MASK) >> Self.SCALE_SHIFT) diff --git a/src/decimojo/maths/arithmetics.mojo b/src/decimojo/maths/arithmetics.mojo index b5abab6..0197bc4 100644 --- a/src/decimojo/maths/arithmetics.mojo +++ b/src/decimojo/maths/arithmetics.mojo @@ -29,6 +29,7 @@ from decimojo.rounding_mode import RoundingMode # ===----------------------------------------------------------------------=== # +# TODO: Like `multiply` use combined bits to determine the appropriate method fn add(x1: Decimal, x2: Decimal) raises -> Decimal: """ Adds two Decimal values and returns a new Decimal containing the sum. @@ -120,7 +121,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: var scale = min( max(x1.scale(), x2.scale()), Decimal.MAX_VALUE_DIGITS - - decimojo.utility.number_of_significant_digits(summation), + - decimojo.utility.number_of_digits(summation), ) ## If summation > 7922816251426433759354395033 if (summation > Decimal.MAX_AS_UINT128 // 10) and (scale > 0): @@ -149,7 +150,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: var scale = min( max(x1.scale(), x2.scale()), Decimal.MAX_VALUE_DIGITS - - decimojo.utility.number_of_significant_digits(diff), + - decimojo.utility.number_of_digits(diff), ) ## If summation > 7922816251426433759354395033 if (diff > Decimal.MAX_AS_UINT128 // 10) and (scale > 0): @@ -187,10 +188,10 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: truncated_summation = decimojo.utility.truncate_to_max( truncated_summation ) - final_scale = decimojo.utility.number_of_significant_digits( + final_scale = decimojo.utility.number_of_digits( truncated_summation ) - ( - decimojo.utility.number_of_significant_digits(summation) + decimojo.utility.number_of_digits(summation) - max(x1.scale(), x2.scale()) ) @@ -240,10 +241,10 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: truncated_summation = decimojo.utility.truncate_to_max( truncated_summation ) - final_scale = decimojo.utility.number_of_significant_digits( + final_scale = decimojo.utility.number_of_digits( truncated_summation ) - ( - decimojo.utility.number_of_significant_digits(summation) + decimojo.utility.number_of_digits(summation) - max(x1.scale(), x2.scale()) ) @@ -286,7 +287,7 @@ fn subtract(x1: Decimal, x2: Decimal) raises -> Decimal: raise Error("Error in `subtract()`; ", e) -fn multiply(x1: Decimal, x2: Decimal) -> Decimal: +fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: """ Multiplies two Decimal values and returns a new Decimal containing the product. @@ -297,127 +298,299 @@ fn multiply(x1: Decimal, x2: Decimal) -> Decimal: Returns: A new Decimal containing the product of x1 and x2. """ - # Special cases for zero - if x1.is_zero() or x2.is_zero(): - # For zero, we need to preserve the scale + + var x1_coef = x1.coefficient() + var x2_coef = x2.coefficient() + var x1_scale = x1.scale() + var x2_scale = x2.scale() + var combined_scale = x1_scale + x2_scale + """Combined scale of the two operands.""" + var is_nagative = x1.is_negative() != x2.is_negative() + + # SPECIAL CASE: zero + # Return zero while preserving the scale + if x1_coef == 0 or x2_coef == 0: var result = Decimal.ZERO() - var result_scale = min(x1.scale() + x2.scale(), Decimal.MAX_PRECISION) + var result_scale = min(combined_scale, Decimal.MAX_PRECISION) result.flags = UInt32( (result_scale << Decimal.SCALE_SHIFT) & Decimal.SCALE_MASK ) return result - # Calculate the combined scale (sum of both scales) - var combined_scale = x1.scale() + x2.scale() + # SPECIAL CASE: Both operands have coefficient of 1 + if x1_coef == 1 and x2_coef == 1: + # If the combined scale exceeds the maximum precision, + # return 0 with leading zeros after the decimal point and correct sign + if combined_scale > Decimal.MAX_PRECISION: + return Decimal( + 0, + 0, + 0, + is_nagative, + Decimal.MAX_PRECISION, + ) + # Otherwise, return 1 with correct sign and scale + var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + return Decimal(1, 0, 0, is_nagative, final_scale) + + # SPECIAL CASE: First operand has coefficient of 1 + if x1_coef == 1: + # If x1 is 1, return x2 with correct sign + if x1_scale == 0: + var result = x2 + result.flags &= ~Decimal.SIGN_MASK + if is_nagative: + result.flags |= Decimal.SIGN_MASK + return result + else: + var mul = x2_coef + # Rounding may be needed. + var num_digits_mul = decimojo.utility.number_of_digits(mul) + var num_digits_to_keep = num_digits_mul - ( + combined_scale - Decimal.MAX_PRECISION + ) + var truncated_mul = decimojo.utility.truncate_to_digits( + mul, num_digits_to_keep + ) + var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + var low = UInt32(truncated_mul & 0xFFFFFFFF) + var mid = UInt32((truncated_mul >> 32) & 0xFFFFFFFF) + var high = UInt32((truncated_mul >> 64) & 0xFFFFFFFF) + return Decimal( + low, + mid, + high, + is_nagative, + final_scale, + ) - # Determine the sign of the result (XOR of signs) - var result_is_negative = x1.is_negative() != x2.is_negative() + # SPECIAL CASE: Second operand has coefficient of 1 + if x2_coef == 1: + # If x2 is 1, return x1 with correct sign + if x2_scale == 0: + var result = x1 + result.flags &= ~Decimal.SIGN_MASK + if is_nagative: + result.flags |= Decimal.SIGN_MASK + return result + else: + var mul = x1_coef + # Rounding may be needed. + var num_digits_mul = decimojo.utility.number_of_digits(mul) + var num_digits_to_keep = num_digits_mul - ( + combined_scale - Decimal.MAX_PRECISION + ) + var truncated_mul = decimojo.utility.truncate_to_digits( + mul, num_digits_to_keep + ) + var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + var low = UInt32(truncated_mul & 0xFFFFFFFF) + var mid = UInt32((truncated_mul >> 32) & 0xFFFFFFFF) + var high = UInt32((truncated_mul >> 64) & 0xFFFFFFFF) + return Decimal( + low, + mid, + high, + is_nagative, + final_scale, + ) + + # Determine the number of bits in the coefficients + # Used to determine the appropriate multiplication method + # The coefficient of result would be the sum of the two numbers of bits + var x1_num_bits = decimojo.utility.number_of_bits(x1_coef) + """Number of significant bits in the coefficient of x1.""" + var x2_num_bits = decimojo.utility.number_of_bits(x2_coef) + """Number of significant bits in the coefficient of x2.""" + var combined_num_bits = x1_num_bits + x2_num_bits + """Number of significant bits in the coefficient of the result.""" + + # SPECIAL CASE: Both operands are true integers + if x1_scale == 0 and x2_scale == 0: + # Small integers, use UInt64 multiplication + if combined_num_bits <= 64: + var mul: UInt64 = UInt64(x1.low) * UInt64(x2.low) + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + return Decimal(low, mid, 0, is_nagative, 0) + + # Moderate integers, use UInt128 multiplication + elif combined_num_bits <= 128: + var mul: UInt128 = UInt128(x1_coef) * UInt128(x2_coef) + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, is_nagative, 0) + + # Large integers, use UInt256 multiplication + else: + var mul: UInt256 = UInt256(x1_coef) * UInt256(x2_coef) + if mul > Decimal.MAX_AS_UINT256: + raise Error("Error in `multiply()`: Decimal overflow") + else: + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, is_nagative, 0) + + # SPECIAL CASE: Both operands are integers but with scales + # Examples: 123.0 * 456.00 + if x1.is_integer() and x2.is_integer(): + var x1_integral_part = x1_coef // (UInt128(10) ** UInt128(x1_scale)) + var x2_integral_part = x2_coef // (UInt128(10) ** UInt128(x2_scale)) + var mul: UInt256 = UInt256(x1_integral_part) * UInt256(x2_integral_part) + if mul > Decimal.MAX_AS_UINT256: + raise Error("Error in `multiply()`: Decimal overflow") + else: + var num_digits = decimojo.utility.number_of_digits(mul) + var final_scale = min( + Decimal.MAX_VALUE_DIGITS - num_digits, combined_scale + ) + # Scale up before it overflows + mul = mul * 10**final_scale + if mul > Decimal.MAX_AS_UINT256: + mul = mul // 10 + final_scale -= 1 + + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + return Decimal( + low, + mid, + high, + is_nagative, + final_scale, + ) + + # GENERAL CASES: Decimal multiplication with any scales - # Extract the components for multiplication - var a_low = UInt64(x1.low) - var a_mid = UInt64(x1.mid) - var a_high = UInt64(x1.high) - - var b_low = UInt64(x2.low) - var b_mid = UInt64(x2.mid) - var b_high = UInt64(x2.high) - - # Perform 96-bit by 96-bit multiplication - var r0 = a_low * b_low - var r1_a = a_low * b_mid - var r1_b = a_mid * b_low - var r2_a = a_low * b_high - var r2_b = a_mid * b_mid - var r2_c = a_high * b_low - var r3_a = a_mid * b_high - var r3_b = a_high * b_mid - var r4 = a_high * b_high - - # Accumulate results with carries - var c0 = r0 & 0xFFFFFFFF - var c1 = (r0 >> 32) + (r1_a & 0xFFFFFFFF) + (r1_b & 0xFFFFFFFF) - var c2 = (r1_a >> 32) + (r1_b >> 32) + (r2_a & 0xFFFFFFFF) + ( - r2_b & 0xFFFFFFFF - ) + (r2_c & 0xFFFFFFFF) + (c1 >> 32) - c1 = c1 & 0xFFFFFFFF # Mask after carry - - var c3 = (r2_a >> 32) + (r2_b >> 32) + (r2_c >> 32) + ( - r3_a & 0xFFFFFFFF - ) + (r3_b & 0xFFFFFFFF) + (c2 >> 32) - c2 = c2 & 0xFFFFFFFF # Mask after carry - - var c4 = (r3_a >> 32) + (r3_b >> 32) + (c3 >> 32) + r4 - c3 = c3 & 0xFFFFFFFF # Mask after carry - - var result_low = UInt32(c0) - var result_mid = UInt32(c1) - var result_high = UInt32(c2) - - # If we have overflow, we need to adjust the scale by dividing - # BUT ONLY enough to fit the result in 96 bits - no more - var scale_reduction = 0 - if c3 > 0 or c4 > 0: - # Calculate minimum shifts needed to fit the result - while c3 > 0 or c4 > 0: - var remainder = UInt64(0) - - # Process c4 - var new_c4 = c4 / 10 - remainder = c4 % 10 - - # Process c3 with remainder from c4 - var new_c3 = (remainder << 32 | c3) / 10 - remainder = (remainder << 32 | c3) % 10 - - # Process c2 with remainder from c3 - var new_c2 = (remainder << 32 | c2) / 10 - remainder = (remainder << 32 | c2) % 10 - - # Process c1 with remainder from c2 - var new_c1 = (remainder << 32 | c1) / 10 - remainder = (remainder << 32 | c1) % 10 - - # Process c0 with remainder from c1 - var new_c0 = (remainder << 32 | c0) / 10 - - # Update values - c4 = new_c4 - c3 = new_c3 - c2 = new_c2 - c1 = new_c1 - c0 = new_c0 - - scale_reduction += 1 - - # Update result components after shifting - result_low = UInt32(c0) - result_mid = UInt32(c1) - result_high = UInt32(c2) - - # Create the result with adjusted values - var result = Decimal(result_low, result_mid, result_high, 0) - - # IMPORTANT: We account for the scale reduction separately from MAX_PRECISION capping - # First, apply the technical scale reduction needed due to overflow - var adjusted_scale = combined_scale - scale_reduction - - # THEN cap at MAX_PRECISION - var final_scale = min(adjusted_scale, Decimal.MAX_SCALE) - - # Set the flags with the correct scale - result.flags = UInt32( - (final_scale << Decimal.SCALE_SHIFT) & Decimal.SCALE_MASK + # SUB-CASE: Both operands are small + # The bits of the product will not exceed 96 bits + # It can just fit into Decimal's capacity without overflow + # Result coefficient will less than 2^96 - 1 = 79228162514264337593543950335 + # Examples: 1.23 * 4.56 + if combined_num_bits <= 96: + var mul: UInt128 = x1_coef * x2_coef + + # Combined scale more than max precision, no need to truncate + if combined_scale <= Decimal.MAX_PRECISION: + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, is_nagative, combined_scale) + + # Combined scale no more than max precision, truncate with rounding + else: + var num_digits = decimojo.utility.number_of_digits(mul) + var num_digits_to_keep = num_digits - ( + combined_scale - Decimal.MAX_PRECISION + ) + mul = decimojo.utility.truncate_to_digits(mul, num_digits_to_keep) + var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, is_nagative, final_scale) + + # SUB-CASE: Both operands are moderate + # The bits of the product will not exceed 128 bits + # Result coefficient will less than 2^128 - 1 but more than 2^96 - 1 + # IMPORTANT: This means that the product will exceed Decimal's capacity + # Either raises an error if intergral part overflows + # Or truncates the product to fit into Decimal's capacity + if combined_num_bits <= 128: + var mul: UInt128 = x1_coef * x2_coef + + # Check outflow + # The number of digits of the integral part + var num_digits_of_integral_part = decimojo.utility.number_of_digits( + mul + ) - combined_scale + # Truncated first 29 digits + var truncated_mul_at_max_length = decimojo.utility.truncate_to_digits( + mul, Decimal.MAX_VALUE_DIGITS + ) + if (num_digits_of_integral_part >= Decimal.MAX_VALUE_DIGITS) & ( + truncated_mul_at_max_length > Decimal.MAX_AS_UINT128 + ): + raise Error("Error in `multiply()`: Decimal overflow") + + # Otherwise, the value will not overflow even after rounding + # Determine the final scale after rounding + # If the first 29 digits does not exceed the limit, + # the final coefficient can be of 29 digits. + # The final scale can be 29 - num_digits_of_integral_part. + var num_digits_of_decimal_part = Decimal.MAX_VALUE_DIGITS - num_digits_of_integral_part + # If the first 29 digits exceed the limit, + # we need to adjust the num_digits_of_decimal_part by -1 + # so that the final coefficient will be of 28 digits. + if truncated_mul_at_max_length > Decimal.MAX_AS_UINT128: + num_digits_of_decimal_part -= 1 + mul = decimojo.utility.truncate_to_digits( + mul, Decimal.MAX_VALUE_DIGITS - 1 + ) + else: + mul = truncated_mul_at_max_length + + # I think combined_scale should always be smaller + var final_scale = min(num_digits_of_decimal_part, combined_scale) + + # Extract the 32-bit components from the UInt128 product + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, is_nagative, final_scale) + + # REMAINING CASES: Both operands are big + # The bits of the product will not exceed 192 bits + # Result coefficient will less than 2^192 - 1 but more than 2^128 - 1 + # IMPORTANT: This means that the product will exceed Decimal's capacity + # Either raises an error if intergral part overflows + # Or truncates the product to fit into Decimal's capacity + var mul: UInt256 = UInt256(x1_coef) * UInt256(x2_coef) + + # Check outflow + # The number of digits of the integral part + var num_digits_of_integral_part = decimojo.utility.number_of_digits( + mul + ) - combined_scale + # Truncated first 29 digits + var truncated_mul_at_max_length = decimojo.utility.truncate_to_digits( + mul, Decimal.MAX_VALUE_DIGITS ) - if result_is_negative: - result.flags |= Decimal.SIGN_MASK + # Check for overflow of the integral part after rounding + if (num_digits_of_integral_part >= Decimal.MAX_VALUE_DIGITS) & ( + truncated_mul_at_max_length > Decimal.MAX_AS_UINT256 + ): + raise Error("Error in `multiply()`: Decimal overflow") + + # Otherwise, the value will not overflow even after rounding + # Determine the final scale after rounding + # If the first 29 digits does not exceed the limit, + # the final coefficient can be of 29 digits. + # The final scale can be 29 - num_digits_of_integral_part. + var num_digits_of_decimal_part = Decimal.MAX_VALUE_DIGITS - num_digits_of_integral_part + # If the first 29 digits exceed the limit, + # we need to adjust the num_digits_of_decimal_part by -1 + # so that the final coefficient will be of 28 digits. + if truncated_mul_at_max_length > Decimal.MAX_AS_UINT256: + num_digits_of_decimal_part -= 1 + mul = decimojo.utility.truncate_to_digits( + mul, Decimal.MAX_VALUE_DIGITS - 1 + ) + else: + mul = truncated_mul_at_max_length - # Handle excess precision separately AFTER handling overflow - # (this shouldn't be reducing scale twice) - if adjusted_scale > Decimal.MAX_PRECISION: - var scale_diff = adjusted_scale - Decimal.MAX_PRECISION - result = result._scale_down(scale_diff, RoundingMode.HALF_EVEN()) + # I think combined_scale should always be smaller + final_scale = min(num_digits_of_decimal_part, combined_scale) - return result + # Extract the 32-bit components from the UInt256 product + var low = UInt32(mul & 0xFFFFFFFF) + var mid = UInt32((mul >> 32) & 0xFFFFFFFF) + var high = UInt32((mul >> 64) & 0xFFFFFFFF) + + return Decimal(low, mid, high, is_nagative, final_scale) fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: diff --git a/src/decimojo/maths/rounding.mojo b/src/decimojo/maths/rounding.mojo index c5391d7..dfdbb7b 100644 --- a/src/decimojo/maths/rounding.mojo +++ b/src/decimojo/maths/rounding.mojo @@ -29,7 +29,7 @@ from decimojo.rounding_mode import RoundingMode fn round( number: Decimal, - decimal_places: Int, + decimal_places: Int = 0, rounding_mode: RoundingMode = RoundingMode.HALF_EVEN(), ) -> Decimal: """ @@ -38,17 +38,25 @@ fn round( Args: number: The Decimal to round. decimal_places: Number of decimal places to round to. - rounding_mode: Rounding mode to use (defaults to HALF_EVEN/banker's rounding). + Defaults to 0. + rounding_mode: Rounding mode to use. + Defaults to HALF_EVEN/banker's rounding. Returns: A new Decimal rounded to the specified number of decimal places. """ var current_scale = number.scale() - # If already at the desired scale, return a copy + # CASE: If already at the desired scale + # Return a copy + # round(Decimal("123.456"), 3) -> Decimal("123.456") if current_scale == decimal_places: return number + # TODO: CASE: If the number is an integer + # Return with more or less zeros until the desired scale + # round(Decimal("123"), 2) -> Decimal("123.00") + # If we need more decimal places, scale up if decimal_places > current_scale: return number._scale_up(decimal_places - current_scale) diff --git a/src/decimojo/utility.mojo b/src/decimojo/utility.mojo index 9e20727..3ed8431 100644 --- a/src/decimojo/utility.mojo +++ b/src/decimojo/utility.mojo @@ -50,8 +50,8 @@ fn bitcast[dtype: DType](dec: Decimal) -> Scalar[dtype]: fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: """ - Truncates a UInt256 or UInt128 value to maximum possible value of Decimal - coefficient with rounding. + Truncates a UInt256 or UInt128 value to be as closer to the max value of + Decimal coefficient (`2^96 - 1`) as possible with rounding. Uses banker's rounding (ROUND_HALF_EVEN) for any truncated digits. `792281625142643375935439503356` will be truncated to `7922816251426433759354395034`. @@ -64,19 +64,19 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: Args: value: The UInt256 value to truncate. + Constraints: + `dtype` must be either `DType.uint128` or `DType.uint256`. + Returns: The truncated UInt256 value, guaranteed to fit within 96 bits. """ alias ValueType = Scalar[dtype] - # TODO: Make this compile-time check instead of rasing an error - # @parameter - # if (dtype != DType.uint128) and (dtype != DType.uint256): - # raise Error( - # "Error in `truncate_to_max`: dtype must be either uint128 or" - # " uint256." - # ) + constrained[ + dtype == DType.uint128 or dtype == DType.uint256, + "must be uint128 or uint256", + ]() # If the value is already less than the maximum possible value, return it if value <= ValueType(Decimal.MAX_AS_UINT128): @@ -85,7 +85,7 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: else: # Calculate how many digits we need to truncate # Calculate how many digits to keep (MAX_VALUE_DIGITS = 29) - var num_digits = number_of_significant_digits(value) + var num_digits = number_of_digits(value) var digits_to_remove = num_digits - Decimal.MAX_VALUE_DIGITS # Collect digits for rounding decision @@ -176,16 +176,168 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: return truncated_value -fn number_of_significant_digits[dtype: DType, //](x: Scalar[dtype]) -> Int: +# TODO: Evalulate whether this can replace truncate_to_max in some cases. +# TODO: Add rounding modes to this function. +fn truncate_to_digits[ + dtype: DType, // +](value: Scalar[dtype], num_digits: Int) -> Scalar[dtype]: """ - Returns the number of significant digits in a scalar value. - ***WARNING***: The input must be an integer. + Truncates a UInt256 or UInt128 value to the specified number of digits. + Uses banker's rounding (ROUND_HALF_EVEN) for any truncated digits. + `792281625142643375935439503356` with digits 2 will be truncated to `79`. + `997` with digits 2 will be truncated to `100`. + + This is useful in two cases: + (1) When you want to evaluate whether the coefficient will overflow after + rounding, just look the first N digits (after rounding). If the truncated + value is larger than the maximum, then it will overflow. Then you need to + either raise an error (in case scale = 0 or integral part overflows), + or keep only the first 28 digits in the coefficient. + (2) When you want to round a value. + + The function is useful in the following cases. + + When you want to apply a scale of 31 to the coefficient `997`, it will be + `0.0000000000000000000000000000997` with 31 digits. However, we can only + store 28 digits in the coefficient (Decimal.MAX_PRECISION = 28). + Therefore, we need to truncate the coefficient to 0 (`3 - (31 - 28)`) digits + and round it to the nearest even number. + The truncated ceofficient will be `1`. + Note that `truncated_digits = 1` which is not equal to + `num_digits = 0`, meaning there is a rounding to next digit. + The final decimal value will be `0.0000000000000000000000000001`. + + When you want to apply a scale of 29 to the coefficient `234567`, it will be + `0.00000000000000000000000234567` with 29 digits. However, we can only + store 28 digits in the coefficient (Decimal.MAX_PRECISION = 28). + Therefore, we need to truncate the coefficient to 5 (`6 - (29 - 28)`) digits + and round it to the nearest even number. + The truncated ceofficient will be `23457`. + The final decimal value will be `0.0000000000000000000000023457`. + + When you want to apply a scale of 5 to the coefficient `234567`, it will be + `2.34567` with 5 digits. + Since `num_digits_to_keep = 6 - (5 - 28) = 29`, + it is greater and equal to the number of digits of the input value. + The function will return the value as it is. + + It can also be used for rounding function. For example, if you want to round + `12.34567` (`1234567` with scale `5`) to 2 digits, + the function input will be `234567` and `4 = (7 - 5) + 2`. + That is (number of digits - scale) + number of rounding points. + The output is `1235`. + + Parameters: + dtype: Must be either uint128 or uint256. + + Args: + value: The UInt256 value to truncate. + num_digits: The number of significant digits to evalulate. + + Constraints: + `dtype` must be either `DType.uint128` or `DType.uint256`. + + Returns: + The truncated UInt256 value, guaranteed to fit within 96 bits. """ - var temp = x - var digit_count: Int = 0 - while temp > 0: - temp //= 10 - digit_count += 1 + alias ValueType = Scalar[dtype] + + constrained[ + dtype == DType.uint128 or dtype == DType.uint256, + "must be uint128 or uint256", + ]() + + if num_digits < 0: + return 0 + + var num_significant_digits = number_of_digits(value) + # If the number of digits is less than or equal to the specified digits, + # return the value + if num_significant_digits <= num_digits: + return value + + else: + # Calculate how many digits we need to truncate + # Calculate how many digits to keep (MAX_VALUE_DIGITS = 29) + var num_digits_to_remove = num_significant_digits - num_digits + + # Collect digits for rounding decision + divisor = ValueType(10) ** ValueType(num_digits_to_remove) + truncated_value = value // divisor + var remainder = value % divisor + + # Get the most significant digit of the remainder for rounding + var rounding_digit = remainder // 10 ** (num_digits_to_remove - 1) + + # Check if we need to round up based on banker's rounding (ROUND_HALF_EVEN) + var round_up = False + + # If rounding digit is > 5, round up + if rounding_digit > 5: + round_up = True + # If rounding digit is 5, check if there are any non-zero digits after it + elif rounding_digit == 5: + var has_nonzero_after = remainder > 5 * 10 ** ( + num_digits_to_remove - 1 + ) + # If there are non-zero digits after, round up + if has_nonzero_after: + round_up = True + # Otherwise, round to even (round up if last kept digit is odd) + else: + round_up = (truncated_value % 2) == 1 + + # Apply rounding if needed + if round_up: + truncated_value += 1 + + return truncated_value + + +fn number_of_digits[dtype: DType, //](owned value: Scalar[dtype]) -> Int: + """ + Returns the number of (significant) digits in an intergral value. + + Constraints: + `dtype` must be integral. + """ + + constrained[ + dtype.is_integral(), + "must be intergral", + ]() + + if value < 0: + value = -value + + var count = 0 + while value > 0: + value //= 10 + count += 1 + + return count + + +fn number_of_bits[dtype: DType, //](owned value: Scalar[dtype]) -> Int: + """ + Returns the number of significant bits in an integer value. + + Constraints: + `dtype` must be integral. + """ + + constrained[ + dtype.is_integral(), + "must be intergral", + ]() + + if value < 0: + value = -value + + var count = 0 + while value > 0: + value >>= 1 + count += 1 - return digit_count + return count diff --git a/tests/test_utility.mojo b/tests/test_utility.mojo index 952a717..5a4a652 100644 --- a/tests/test_utility.mojo +++ b/tests/test_utility.mojo @@ -6,39 +6,39 @@ from testing import assert_equal, assert_true import max from decimojo.prelude import * -from decimojo.utility import truncate_to_max, number_of_significant_digits +from decimojo.utility import truncate_to_max, number_of_digits -fn test_number_of_significant_digits() raises: - """Tests for number_of_significant_digits function.""" - print("Testing number_of_significant_digits...") +fn test_number_of_digits() raises: + """Tests for number_of_digits function.""" + print("Testing number_of_digits...") # Test with simple UInt128 values - assert_equal(number_of_significant_digits(UInt128(0)), 0) - assert_equal(number_of_significant_digits(UInt128(1)), 1) - assert_equal(number_of_significant_digits(UInt128(9)), 1) - assert_equal(number_of_significant_digits(UInt128(10)), 2) - assert_equal(number_of_significant_digits(UInt128(123)), 3) - assert_equal(number_of_significant_digits(UInt128(9999)), 4) + assert_equal(number_of_digits(UInt128(0)), 0) + assert_equal(number_of_digits(UInt128(1)), 1) + assert_equal(number_of_digits(UInt128(9)), 1) + assert_equal(number_of_digits(UInt128(10)), 2) + assert_equal(number_of_digits(UInt128(123)), 3) + assert_equal(number_of_digits(UInt128(9999)), 4) # Test with powers of 10 - assert_equal(number_of_significant_digits(UInt128(10**6)), 7) - assert_equal(number_of_significant_digits(UInt128(10**12)), 13) + assert_equal(number_of_digits(UInt128(10**6)), 7) + assert_equal(number_of_digits(UInt128(10**12)), 13) # Test with UInt256 values - assert_equal(number_of_significant_digits(UInt256(0)), 0) - assert_equal(number_of_significant_digits(UInt256(123456789)), 9) - assert_equal(number_of_significant_digits(UInt256(10) ** 20), 21) + assert_equal(number_of_digits(UInt256(0)), 0) + assert_equal(number_of_digits(UInt256(123456789)), 9) + assert_equal(number_of_digits(UInt256(10) ** 20), 21) # Test with large values approaching UInt128 maximum var large_value = UInt128(Decimal.MAX_AS_UINT128) - assert_equal(number_of_significant_digits(large_value), 29) + assert_equal(number_of_digits(large_value), 29) # Test with values larger than UInt128 max (using UInt256) var very_large = UInt256(Decimal.MAX_AS_UINT128) * UInt256(10) - assert_equal(number_of_significant_digits(very_large), 30) + assert_equal(number_of_digits(very_large), 30) - print("✓ All number_of_significant_digits tests passed!") + print("✓ All number_of_digits tests passed!") fn test_truncate_to_max_below_max() raises: @@ -161,6 +161,63 @@ fn test_truncate_to_max_banker_rounding() raises: print("✓ All truncate_to_max banker's rounding tests passed!") +fn test_truncate_to_digits() raises: + """Test the truncate_to_digits function for proper digit truncation and rounding. + """ + print("Testing truncate_to_digits...") + + # Test case 1: Value with more digits than to keep (round to nearest power of 10) + var case1 = UInt128(997) + var case1_expected = UInt128(1) + assert_equal(dm.utility.truncate_to_digits(case1, 0), case1_expected) + + # Test case 2: Value with one more digit than to keep + var case2 = UInt128(234567) + var case2_expected = UInt128(23457) + assert_equal(dm.utility.truncate_to_digits(case2, 5), case2_expected) + + # Test case 3: Value with fewer digits than to keep (should return original) + var case3 = UInt128(234567) + assert_equal(dm.utility.truncate_to_digits(case3, 29), case3) + + # Test case 4: Test banker's rounding with 5 (round to even) + var case4a = UInt128(12345) # Last digit is 5, preceding digit is even + var case4a_expected = UInt128(1234) + assert_equal(dm.utility.truncate_to_digits(case4a, 4), case4a_expected) + + var case4b = UInt128(23455) # Last digit is 5, preceding digit is odd + var case4b_expected = UInt128(2346) + assert_equal(dm.utility.truncate_to_digits(case4b, 4), case4b_expected) + + # Test case 5: Rounding down (< 5) + var case5 = UInt128(12342) + var case5_expected = UInt128(1234) + assert_equal(dm.utility.truncate_to_digits(case5, 4), case5_expected) + + # Test case 6: Rounding up (> 5) + var case6 = UInt128(12347) + var case6_expected = UInt128(1235) + assert_equal(dm.utility.truncate_to_digits(case6, 4), case6_expected) + + # Test case 7: Zero input + var case7 = UInt128(0) + assert_equal(dm.utility.truncate_to_digits(case7, 5), UInt128(0)) + + # Test case 8: Single digit input + var case8 = UInt128(7) + assert_equal(dm.utility.truncate_to_digits(case8, 1), UInt128(7)) + assert_equal( + dm.utility.truncate_to_digits(case8, 0), UInt128(1) + ) # Round to nearest power of 10 + + # Test case 9: Large value with UInt256 + var case9 = UInt256(9876543210987654321) + var case9_expected = UInt256(987654321098765432) + assert_equal(dm.utility.truncate_to_digits(case9, 18), case9_expected) + + print("✓ All truncate_to_digits tests passed!") + + fn test_bitcast() raises: """Test the bitcast utility function for direct memory bit conversion.""" print("Testing utility.bitcast...") @@ -204,12 +261,11 @@ fn test_bitcast() raises: print("✓ All bitcast tests passed!") -# Update the test_all function to include the new test fn test_all() raises: """Run all tests for the utility module.""" print("\n=== Running Utility Module Tests ===\n") - test_number_of_significant_digits() + test_number_of_digits() print() test_truncate_to_max_below_max() @@ -224,6 +280,9 @@ fn test_all() raises: test_bitcast() print() + test_truncate_to_digits() + print() + print("✓✓✓ All utility module tests passed! ✓✓✓")