diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 7c24fe4..fdb03bd 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -13,8 +13,8 @@ jobs: strategy: fail-fast: false matrix: - os: ["macos-latest"] - # os: ["ubuntu-22.04"] + # os: ["macos-latest"] + os: ["ubuntu-22.04"] runs-on: ${{ matrix.os }} timeout-minutes: 30 diff --git a/benches/bench.mojo b/benches/bench.mojo index 5e020a8..ac40f3e 100644 --- a/benches/bench.mojo +++ b/benches/bench.mojo @@ -6,6 +6,7 @@ from bench_sqrt import main as bench_sqrt from bench_from_float import main as bench_from_float from bench_from_string import main as bench_from_string from bench_comparison import main as bench_comparison +from bench_exp import main as bench_exp fn main() raises: @@ -17,3 +18,4 @@ fn main() raises: bench_from_float() bench_from_string() bench_comparison() + bench_exp() diff --git a/benches/bench_exp.mojo b/benches/bench_exp.mojo index 195fbac..a669eda 100644 --- a/benches/bench_exp.mojo +++ b/benches/bench_exp.mojo @@ -71,7 +71,7 @@ fn run_benchmark( var mojo_decimal = Decimal(input_value) var pydecimal = Python.import_module("decimal") var py_decimal = pydecimal.Decimal(input_value) - var py_math = Python.import_module("math") + var _py_math = Python.import_module("math") # Execute the operations once to verify correctness var mojo_result = dm.exponential.exp(mojo_decimal) diff --git a/mojoproject.toml b/mojoproject.toml index 93fc70c..53f4595 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -33,7 +33,7 @@ debug_sqrt = "magic run package && magic run mojo tests/test_sqrt.mojo && magic test = "magic run package && magic run mojo test tests && magic run delete_package" t = "clear && magic run test" test_arith = "magic run package && magic run mojo test tests/test_arithmetics.mojo && magic run delete_package" -test_div = "magic run package && magic run mojo test tests/test_division.mojo && magic run delete_package" +test_divide = 
"magic run package && magic run mojo test tests/test_division.mojo && magic run delete_package" test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && magic run delete_package" test_round = "magic run package && magic run mojo test tests/test_round.mojo && magic run delete_package" test_creation = "magic run package && magic run mojo test tests/test_creation.mojo && magic run delete_package" @@ -46,8 +46,8 @@ test_exp = "magic run package && magic run mojo test tests/test_exp.mojo && magi # benches bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" b = "clear && magic run bench" -bench_mul = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .. && magic run delete_package" -bench_div = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .. && magic run delete_package" +bench_multiply = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .. && magic run delete_package" +bench_divide = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .. && magic run delete_package" bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo && cd .. && magic run delete_package" bench_round = "magic run package && cd benches && magic run mojo bench_round.mojo && cd .. && magic run delete_package" bench_from_float = "magic run package && cd benches && magic run mojo bench_from_float.mojo && cd .. && magic run delete_package" diff --git a/src/decimojo/arithmetics.mojo b/src/decimojo/arithmetics.mojo index 32ab77d..15c9566 100644 --- a/src/decimojo/arithmetics.mojo +++ b/src/decimojo/arithmetics.mojo @@ -36,6 +36,7 @@ Implements functions for mathematical operations on Decimal objects. 
""" +import time import testing from decimojo.decimal import Decimal @@ -64,53 +65,57 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: var x1_scale = x1.scale() var x2_scale = x2.scale() - # Special case for zeros - + # CASE: Zeros if x1_coef == 0 and x2_coef == 0: var scale = max(x1_scale, x2_scale) return Decimal(0, 0, 0, scale, False) elif x1_coef == 0: - var sum_coef = x2_coef - var scale = min( - max(x1_scale, x2_scale), - Decimal.MAX_NUM_DIGITS - - decimojo.utility.number_of_digits(x2.to_uint128()), - ) - ## If x2_coef > 7922816251426433759354395033 - if ( - (x2_coef > Decimal.MAX_AS_UINT128 // 10) - and (scale > 0) - and (scale > x2_scale) - ): - scale -= 1 - sum_coef *= UInt128(10) ** (scale - x2_scale) - var low = UInt32(sum_coef & 0xFFFFFFFF) - var mid = UInt32((sum_coef >> 32) & 0xFFFFFFFF) - var high = UInt32((sum_coef >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, scale, x2.is_negative()) + if x1_scale <= x2_scale: + return x2 + + else: # x1_scale > x2_scale + # Scale up x2_coef to match x1_scale + + var sum_coef = x2_coef + var scale = min( + max(x1_scale, x2_scale), + Decimal.MAX_NUM_DIGITS + - decimojo.utility.number_of_digits(x2.to_uint128()), + ) + ## If x2_coef > 7922816251426433759354395033 + if ( + (x2_coef > Decimal.MAX_AS_UINT128 // 10) + and (scale > 0) + and (scale > x2_scale) + ): + scale -= 1 + sum_coef *= UInt128(10) ** (scale - x2_scale) + return Decimal.from_uint128(sum_coef, scale, x2.is_negative()) elif x2_coef == 0: - var sum_coef = x1_coef - var scale = min( - max(x1_scale, x2_scale), - Decimal.MAX_NUM_DIGITS - - decimojo.utility.number_of_digits(x1.to_uint128()), - ) - ## If x1_coef > 7922816251426433759354395033 - if ( - (x1_coef > Decimal.MAX_AS_UINT128 // 10) - and (scale > 0) - and (scale > x1_scale) - ): - scale -= 1 - sum_coef *= UInt128(10) ** (scale - x1_scale) - var low = UInt32(sum_coef & 0xFFFFFFFF) - var mid = UInt32((sum_coef >> 32) & 0xFFFFFFFF) - var high = UInt32((sum_coef >> 64) & 0xFFFFFFFF) - return 
Decimal(low, mid, high, scale, x1.is_negative()) - - # Integer addition with scale of 0 (true integers) + if x2_scale <= x1_scale: + return x1 + + else: # x2_scale > x1_scale + # Scale up x1_coef to match x2_scale + var sum_coef = x1_coef + var scale = min( + max(x1_scale, x2_scale), + Decimal.MAX_NUM_DIGITS + - decimojo.utility.number_of_digits(x1.to_uint128()), + ) + ## If x1_coef > 7922816251426433759354395033 + if ( + (x1_coef > Decimal.MAX_AS_UINT128 // 10) + and (scale > 0) + and (scale > x1_scale) + ): + scale -= 1 + sum_coef *= UInt128(10) ** (scale - x1_scale) + return Decimal.from_uint128(sum_coef, scale, x1.is_negative()) + + # CASE: Integer addition with scale of 0 (true integers) elif x1_scale == 0 and x2_scale == 0: # Same sign: add absolute values and keep the sign if x1.is_negative() == x2.is_negative(): @@ -147,7 +152,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: return Decimal(low, mid, high, 0, is_negative) - # Integer addition with positive scales + # CASE: Integer addition with positive scales elif x1.is_integer() and x2.is_integer(): # Same sign: add absolute values and keep the sign if x1.is_negative() == x2.is_negative(): @@ -206,94 +211,130 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: return Decimal(low, mid, high, scale, is_negative) - # Float addition with the same scale + # CASE: Float addition with the same scale elif x1_scale == x2_scale: - var summation: Int128 # 97-bit signed integer can be stored in Int128 - summation = (-1) ** x1.is_negative() * Int128(x1_coef) + ( - -1 - ) ** x2.is_negative() * Int128(x2_coef) - - var is_negative = summation < 0 - if is_negative: - summation = -summation + var summation: UInt128 + var is_negative: Bool - # Now we need to truncate the summation to fit in 96 bits - var final_scale: Int - var truncated_summation = UInt128(summation) + if x1.is_negative() == x2.is_negative(): + is_negative = x1.is_negative() + summation = x1_coef + x2_coef + else: # Different signs + if x1_coef > 
x2_coef: + summation = x1_coef - x2_coef + is_negative = x1.is_negative() + elif x1_coef < x2_coef: + summation = x2_coef - x1_coef + is_negative = x2.is_negative() + else: # x1_coef == x2_coef + return Decimal.from_uint128(UInt128(0), x1_scale, False) # If the summation fits in 96 bits, we can use the original scale - if summation < Decimal.MAX_AS_INT128: - final_scale = x1_scale + if summation < Decimal.MAX_AS_UINT128: + return Decimal.from_uint128(summation, x1_scale, is_negative) - # Otherwise, we need to truncate the summation to fit in 96 bits + # Otherwise, it is >= 29 digits + # we need to truncate the summation to fit in 96 bits else: - truncated_summation = decimojo.utility.truncate_to_max( - truncated_summation + var ndigits_summation = decimojo.utility.number_of_digits(summation) + var ndigits_int_summation = ndigits_summation - x1_scale + var final_scale = Decimal.MAX_NUM_DIGITS - ndigits_int_summation + + var truncated_summation = decimojo.utility.round_to_keep_first_n_digits( + summation, Decimal.MAX_NUM_DIGITS ) - final_scale = decimojo.utility.number_of_digits( - truncated_summation - ) - ( - decimojo.utility.number_of_digits(summation) - - max(x1_scale, x2_scale) + if truncated_summation > Decimal.MAX_AS_UINT128: + truncated_summation = ( + decimojo.utility.round_to_keep_first_n_digits( + summation, Decimal.MAX_NUM_DIGITS - 1 + ) + ) + final_scale -= 1 + + return Decimal.from_uint128( + truncated_summation, final_scale, is_negative ) - # Extract the 32-bit components from the Int256 difference - low = UInt32(truncated_summation & 0xFFFFFFFF) - mid = UInt32((truncated_summation >> 32) & 0xFFFFFFFF) - high = UInt32((truncated_summation >> 64) & 0xFFFFFFFF) + # CASE: Float addition with different scales + else: # x1_scale != x2_scale + var summation: UInt256 + var is_negative: Bool - return Decimal(low, mid, high, final_scale, is_negative) - - # Float addition which with different scales - else: - var summation: Int256 - if x1_scale == 
x2_scale: - summation = (-1) ** x1.is_negative() * Int256(x1_coef) + ( - -1 - ) ** x2.is_negative() * Int256(x2_coef) - elif x1_scale > x2_scale: - summation = (-1) ** x1.is_negative() * Int256(x1_coef) + ( - -1 - ) ** x2.is_negative() * Int256(x2_coef) * Int256(10) ** ( + if x1_scale > x2_scale: + # Scale up x2_coef to match x1_scale + var x1_coef_scaled: UInt256 = UInt256(x1_coef) + var x2_coef_scaled: UInt256 = UInt256(x2_coef) * UInt256(10) ** ( x1_scale - x2_scale ) - else: - summation = (-1) ** x1.is_negative() * Int256(x1_coef) * Int256( - 10 - ) ** (x2_scale - x1_scale) + (-1) ** x2.is_negative() * Int256( - x2_coef - ) - var is_negative = summation < 0 - if is_negative: - summation = -summation + if x1.is_negative() == x2.is_negative(): + is_negative = x1.is_negative() + summation = x1_coef_scaled + x2_coef_scaled + else: # Different signs + if x1_coef_scaled > x2_coef_scaled: + summation = x1_coef_scaled - x2_coef_scaled + is_negative = x1.is_negative() + elif x1_coef_scaled < x2_coef_scaled: + summation = x2_coef_scaled - x1_coef_scaled + is_negative = x2.is_negative() + else: + return Decimal.from_uint128(UInt128(0), x1_scale, False) - # Now we need to truncate the summation to fit in 96 bits - var final_scale: Int - var truncated_summation = UInt256(summation) + else: # x1_scale < x2_scale + # Scale up x1_coef to match x2_scale + var x1_coef_scaled: UInt256 = UInt256(x1_coef) * UInt256(10) ** ( + x2_scale - x1_scale + ) + var x2_coef_scaled: UInt256 = UInt256(x2_coef) + + if x1.is_negative() == x2.is_negative(): + is_negative = x1.is_negative() + summation = x2_coef_scaled + x1_coef_scaled + else: # Different signs + if x1_coef_scaled > x2_coef_scaled: + summation = x1_coef_scaled - x2_coef_scaled + is_negative = x1.is_negative() + elif x1_coef_scaled < x2_coef_scaled: + summation = x2_coef_scaled - x1_coef_scaled + is_negative = x2.is_negative() + else: + return Decimal.from_uint128(UInt128(0), x2_scale, False) # If the summation fits in 96 bits, we 
can use the original scale - if summation < Decimal.MAX_AS_INT256: - final_scale = max(x1_scale, x2_scale) + if summation < Decimal.MAX_AS_UINT256: + return Decimal.from_uint128( + UInt128(summation & 0x00000000_FFFFFFFF_FFFFFFFF_FFFFFFFF), + max(x1_scale, x2_scale), + is_negative, + ) + # Otherwise, it is >= 29 digits # Otherwise, we need to truncate the summation to fit in 96 bits else: - truncated_summation = decimojo.utility.truncate_to_max( - truncated_summation - ) - final_scale = decimojo.utility.number_of_digits( - truncated_summation - ) - ( - decimojo.utility.number_of_digits(summation) - - max(x1_scale, x2_scale) + var ndigits_summation = decimojo.utility.number_of_digits(summation) + var ndigits_int_summation = ndigits_summation - max( + x1_scale, x2_scale ) + var final_scale = Decimal.MAX_NUM_DIGITS - ndigits_int_summation - # Extract the 32-bit components from the Int256 difference - low = UInt32(truncated_summation & 0xFFFFFFFF) - mid = UInt32((truncated_summation >> 32) & 0xFFFFFFFF) - high = UInt32((truncated_summation >> 64) & 0xFFFFFFFF) + truncated_summation = decimojo.utility.round_to_keep_first_n_digits( + summation, Decimal.MAX_NUM_DIGITS + ) + if truncated_summation > Decimal.MAX_AS_UINT256: + truncated_summation = ( + decimojo.utility.round_to_keep_first_n_digits( + summation, Decimal.MAX_NUM_DIGITS - 1 + ) + ) + final_scale -= 1 - return Decimal(low, mid, high, final_scale, is_negative) + return Decimal.from_uint128( + UInt128( + truncated_summation & 0x00000000_FFFFFFFF_FFFFFFFF_FFFFFFFF + ), + final_scale, + is_negative, + ) fn subtract(x1: Decimal, x2: Decimal) raises -> Decimal: @@ -594,18 +635,19 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # IMPORTANT: This means that the product will exceed Decimal's capacity # Either raises an error if intergral part overflows # Or truncates the product to fit into Decimal's capacity + if combined_num_bits <= 128: var prod: UInt128 = x1_coef * x2_coef + # Truncated first 29 digits + 
var truncated_prod_at_max_length = decimojo.utility.round_to_keep_first_n_digits( + prod, Decimal.MAX_NUM_DIGITS + ) # Check outflow # The number of digits of the integral part var num_digits_of_integral_part = decimojo.utility.number_of_digits( prod ) - combined_scale - # Truncated first 29 digits - var truncated_prod_at_max_length = decimojo.utility.round_to_keep_first_n_digits( - prod, Decimal.MAX_NUM_DIGITS - ) if (num_digits_of_integral_part >= Decimal.MAX_NUM_DIGITS) & ( truncated_prod_at_max_length > Decimal.MAX_AS_UINT128 ): @@ -651,17 +693,20 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # IMPORTANT: This means that the product will exceed Decimal's capacity # Either raises an error if intergral part overflows # Or truncates the product to fit into Decimal's capacity + var prod: UInt256 = UInt256(x1_coef) * UInt256(x2_coef) + # Truncated first 29 digits + var truncated_prod_at_max_length = decimojo.utility.round_to_keep_first_n_digits( + prod, Decimal.MAX_NUM_DIGITS + ) + # Check outflow # The number of digits of the integral part var num_digits_of_integral_part = decimojo.utility.number_of_digits( prod ) - combined_scale - # Truncated first 29 digits - var truncated_prod_at_max_length = decimojo.utility.round_to_keep_first_n_digits( - prod, Decimal.MAX_NUM_DIGITS - ) + # Check for overflow of the integral part after rounding if (num_digits_of_integral_part >= Decimal.MAX_NUM_DIGITS) & ( truncated_prod_at_max_length > Decimal.MAX_AS_UINT256 @@ -686,7 +731,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: prod = truncated_prod_at_max_length # I think combined_scale should always be smaller - final_scale = min(num_digits_of_decimal_part, combined_scale) + var final_scale = min(num_digits_of_decimal_part, combined_scale) if final_scale > Decimal.MAX_SCALE: var ndigits_prod = decimojo.utility.number_of_digits(prod) diff --git a/src/decimojo/utility.mojo b/src/decimojo/utility.mojo index 95aa6a2..15dd5bb 100644 --- 
a/src/decimojo/utility.mojo +++ b/src/decimojo/utility.mojo @@ -25,6 +25,7 @@ # ===----------------------------------------------------------------------=== # from memory import UnsafePointer +import time from decimojo.decimal import Decimal @@ -128,7 +129,7 @@ fn scale_up(value: Decimal, owned level: Int) raises -> Decimal: # TODO: Check if multiplication by 10^level would cause overflow # If yes, then raise an error # - var max_coefficient = ~UInt128(0) / UInt128(10**level) + var max_coefficient = ~UInt128(0) / UInt128(10) ** level if coefficient > max_coefficient: # Handle overflow case - limit to maximum value or raise error coefficient = ~UInt128(0) @@ -190,7 +191,7 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: var digits_to_remove = ndigits - Decimal.MAX_NUM_DIGITS # Collect digits for rounding decision - var divisor = ValueType(10) ** ValueType(digits_to_remove) + var divisor = power_of_10[dtype](digits_to_remove) var truncated_value = value // divisor if truncated_value == ValueType(Decimal.MAX_AS_UINT128): @@ -204,7 +205,9 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: var remainder = value % divisor # Get the most significant digit of the remainder for rounding - var rounding_digit = remainder // 10 ** (digits_to_remove - 1) + var rounding_digit = remainder // power_of_10[dtype]( + digits_to_remove - 1 + ) # Check if we need to round up based on banker's rounding (ROUND_HALF_EVEN) var round_up = False @@ -214,7 +217,7 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: round_up = True # If rounding digit is 5, check if there are any non-zero digits after it elif rounding_digit == 5: - var has_nonzero_after = remainder > 5 * 10 ** ( + var has_nonzero_after = remainder > 5 * power_of_10[dtype]( digits_to_remove - 1 ) # If there are non-zero digits after, round up @@ -245,12 +248,14 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: 
digits_to_remove += 1 # Collect digits for rounding decision - divisor = ValueType(10) ** ValueType(digits_to_remove) + divisor = power_of_10[dtype](digits_to_remove) truncated_value = value // divisor var remainder = value % divisor # Get the most significant digit of the remainder for rounding - var rounding_digit = remainder // 10 ** (digits_to_remove - 1) + var rounding_digit = remainder // power_of_10[dtype]( + digits_to_remove - 1 + ) # Check if we need to round up based on banker's rounding (ROUND_HALF_EVEN) var round_up = False @@ -260,7 +265,7 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: round_up = True # If rounding digit is 5, check if there are any non-zero digits after it elif rounding_digit == 5: - var has_nonzero_after = remainder > 5 * 10 ** ( + var has_nonzero_after = remainder > 5 * power_of_10[dtype]( digits_to_remove - 1 ) # If there are non-zero digits after, round up @@ -366,7 +371,8 @@ fn round_to_keep_first_n_digits[ if ndigits < 0: return 0 - var ndigits_of_x = number_of_digits(value) + var ndigits_of_x: Int + ndigits_of_x = number_of_digits(value) # CASE: If the number of digits is greater than or equal to the specified digits # Return the value. 
@@ -387,7 +393,7 @@ fn round_to_keep_first_n_digits[ var ndigits_to_remove = ndigits_of_x - ndigits # Collect digits for rounding decision - var divisor = ValueType(10) ** ValueType(ndigits_to_remove) + var divisor = power_of_10[dtype](ndigits_to_remove) var truncated_value = value // divisor var remainder = value % divisor @@ -402,13 +408,13 @@ fn round_to_keep_first_n_digits[ # If RoundingMode is ROUND_HALF_UP, round up the value if remainder is greater than 5 elif rounding_mode == RoundingMode.ROUND_HALF_UP: - var cutoff_value = 5 * 10 ** (ndigits_to_remove - 1) + var cutoff_value = 5 * power_of_10[dtype](ndigits_to_remove - 1) if remainder >= cutoff_value: truncated_value += 1 # If RoundingMode is ROUND_HALF_EVEN, round to nearest even digit if equidistant else: - var cutoff_value: ValueType = 5 * ValueType(10) ** ( + var cutoff_value: ValueType = 5 * power_of_10[dtype]( ndigits_to_remove - 1 ) if remainder > cutoff_value: @@ -423,28 +429,171 @@ fn round_to_keep_first_n_digits[ return truncated_value -fn number_of_digits[dtype: DType, //](owned value: Scalar[dtype]) -> Int: +@always_inline +fn number_of_digits[dtype: DType](value: Scalar[dtype]) -> Int: """ - Returns the number of (significant) digits in an intergral value. + Returns the number of (significant) digits in an integral value using binary search. + This implementation is significantly faster than loop division. + + Parameters: + dtype: The Mojo scalar type to calculate the number of digits for. + + Args: + value: The integral value to calculate the number of digits for. Constraints: - `dtype` must be integral. + `dtype` must be either `DType.uint128` or `DType.uint256`. + + Returns: + The number of digits in the integral value (0 if the value is 0; any value of 10^58 or more is reported as 59). 
""" constrained[ - dtype.is_integral(), - "must be intergral", + dtype == DType.uint128 or dtype == DType.uint256, + "must be uint128 or uint256", ]() - if value < 0: - value = -value - - var count = 0 - while value > 0: - value //= 10 - count += 1 + alias ValueType = Scalar[dtype] - return count + # Handle edge cases + if value == 0: + return 0 + # Binary search to determine the number of digits + # First check small numbers with direct comparison (most common case) + if value < 10: + return 1 + if value < 100: + return 2 + if value < 1000: + return 3 + if value < 10000: + return 4 + if value < 100000: + return 5 + if value < 1000000: + return 6 + if value < 10000000: + return 7 + if value < 100000000: + return 8 + if value < 1000000000: + return 9 + + # For larger numbers, use binary search with limited indentation + # Medium range: 10^10 to 10^19 + if value < ValueType(10) ** 19: # < 10^19 + if value < ValueType(10) ** 13: # < 10^13 + if value < ValueType(10) ** 10: # < 10^10 + return 10 + if value < ValueType(10) ** 11: # < 10^11 + return 11 + if value < ValueType(10) ** 12: # < 10^12 + return 12 + return 13 + if value < ValueType(10) ** 16: # < 10^16 + if value < ValueType(10) ** 14: # < 10^14 + return 14 + if value < ValueType(10) ** 15: # < 10^15 + return 15 + return 16 + if value < ValueType(10) ** 17: # < 10^17 + return 17 + if value < ValueType(10) ** 18: # < 10^18 + return 18 + return 19 + + # Large range: 10^19 to 10^38 (UInt128 max is ~10^38) + if value < ValueType(10) ** 37: # < 10^37 + if value < ValueType(10) ** 28: # < 10^28 + if value < ValueType(10) ** 22: # < 10^22 + if value < ValueType(10) ** 20: # < 10^20 + return 20 + if value < ValueType(10) ** 21: # < 10^21 + return 21 + return 22 + if value < ValueType(10) ** 24: # < 10^24 + if value < ValueType(10) ** 23: # < 10^23 + return 23 + return 24 + if value < ValueType(10) ** 25: # < 10^25 + return 25 + if value < ValueType(10) ** 26: # < 10^26 + return 26 + if value < ValueType(10) ** 27: # < 
10^27 + return 27 + return 28 + if value < ValueType(10) ** 31: # < 10^31 + if value < ValueType(10) ** 29: # < 10^29 + return 29 + if value < ValueType(10) ** 30: # < 10^30 + return 30 + return 31 + if value < ValueType(10) ** 33: # < 10^33 + if value < ValueType(10) ** 32: # < 10^32 + return 32 + return 33 + if value < ValueType(10) ** 34: # < 10^34 + return 34 + if value < ValueType(10) ** 35: # < 10^35 + return 35 + if value < ValueType(10) ** 36: # < 10^36 + return 36 + return 37 + + # Very large range: 10^37 to 10^77 (UInt256 max is ~10^77) + if value < ValueType(10) ** 38: # < 10^38 + return 38 + + # For UInt128, the maximum number of digits is 39 + # We can already return the result here + if dtype == DType.uint128: + return 39 + + if value < ValueType(10) ** 39: # < 10^39 + return 39 + + # Use additional binary searches for UInt256 range (10^39 to 10^77) + if value < ValueType(10) ** 58: # < 10^58 + if value < ValueType(10) ** 47: # < 10^47 + if value < ValueType(10) ** 43: # < 10^43 + if value < ValueType(10) ** 40: # < 10^40 + return 40 + if value < ValueType(10) ** 41: # < 10^41 + return 41 + if value < ValueType(10) ** 42: # < 10^42 + return 42 + return 43 + if value < ValueType(10) ** 44: # < 10^44 + return 44 + if value < ValueType(10) ** 45: # < 10^45 + return 45 + if value < ValueType(10) ** 46: # < 10^46 + return 46 + return 47 + if value < ValueType(10) ** 52: # < 10^52 + if value < ValueType(10) ** 48: # < 10^48 + return 48 + if value < ValueType(10) ** 49: # < 10^49 + return 49 + if value < ValueType(10) ** 50: # < 10^50 + return 50 + if value < ValueType(10) ** 51: # < 10^51 + return 51 + return 52 + if value < ValueType(10) ** 54: # < 10^54 + if value < ValueType(10) ** 53: # < 10^53 + return 53 + return 54 + if value < ValueType(10) ** 56: # < 10^56 + if value < ValueType(10) ** 55: # < 10^55 + return 55 + return 56 + if value < ValueType(10) ** 57: # < 10^57 + return 57 + return 58 + + # Digits more than 58 is not possible for Decimal 
products + return 59 fn number_of_bits[dtype: DType, //](owned value: Scalar[dtype]) -> Int: @@ -554,3 +703,153 @@ fn power_of_10_as_uint256(n: Int) raises -> UInt256: _power_of_10_as_uint256_cache.append(next_power) return _power_of_10_as_uint256_cache[n] + + +@always_inline +fn power_of_10[dtype: DType](n: Int) -> Scalar[dtype]: + """ + Returns 10^n, using a dedicated branch for each exponent from 0 to 56. + **WARNING**: The overflow is not checked in this function. + Make sure that n is at most 38 for UInt128 and at most 77 for UInt256. + + Parameters: + dtype: The Mojo scalar type to calculate the power of 10 for. + + Args: + n: The exponent to raise 10 to. + + Constraints: + `dtype` must be either `DType.uint128` or `DType.uint256`. + + Returns: + The value of 10^n as a Mojo scalar. + + Notes: + The powers of 10 are special-cased up to 10^56 since it is twice the maximum + scale of Decimal (28). For larger values, the function calculates the + power of 10 using the built-in `**` operator. + """ + + alias ValueType = Scalar[dtype] + + constrained[ + dtype == DType.uint128 or dtype == DType.uint256, + "must be uint128 or uint256", + ]() + + if n == 0: + return ValueType(1) + if n == 1: + return ValueType(10) + if n == 2: + return ValueType(100) + if n == 3: + return ValueType(1000) + if n == 4: + return ValueType(10000) + if n == 5: + return ValueType(100000) + if n == 6: + return ValueType(1000000) + if n == 7: + return ValueType(10000000) + if n == 8: + return ValueType(100000000) + if n == 9: + return ValueType(1000000000) + if n == 10: + return ValueType(10000000000) + if n == 11: + return ValueType(100000000000) + if n == 12: + return ValueType(1000000000000) + if n == 13: + return ValueType(10000000000000) + if n == 14: + return ValueType(100000000000000) + if n == 15: + return ValueType(1000000000000000) + if n == 16: + return ValueType(10000000000000000) + if n == 17: + return ValueType(100000000000000000) + if n == 18: + return ValueType(1000000000000000000) + if n == 19: + return 
ValueType(10000000000000000000) + if n == 20: + return ValueType(100000000000000000000) + if n == 21: + return ValueType(1000000000000000000000) + if n == 22: + return ValueType(10000000000000000000000) + if n == 23: + return ValueType(100000000000000000000000) + if n == 24: + return ValueType(1000000000000000000000000) + if n == 25: + return ValueType(10000000000000000000000000) + if n == 26: + return ValueType(100000000000000000000000000) + if n == 27: + return ValueType(1000000000000000000000000000) + if n == 28: + return ValueType(10000000000000000000000000000) + if n == 29: + return ValueType(100000000000000000000000000000) + if n == 30: + return ValueType(1000000000000000000000000000000) + if n == 31: + return ValueType(10000000000000000000000000000000) + if n == 32: + return ValueType(100000000000000000000000000000000) + if n == 33: + return ValueType(10) ** 33 + if n == 34: + return ValueType(10) ** 34 + if n == 35: + return ValueType(10) ** 35 + if n == 36: + return ValueType(10) ** 36 + if n == 37: + return ValueType(10) ** 37 + if n == 38: + return ValueType(10) ** 38 + if n == 39: + return ValueType(10) ** 39 + if n == 40: + return ValueType(10) ** 40 + if n == 41: + return ValueType(10) ** 41 + if n == 42: + return ValueType(10) ** 42 + if n == 43: + return ValueType(10) ** 43 + if n == 44: + return ValueType(10) ** 44 + if n == 45: + return ValueType(10) ** 45 + if n == 46: + return ValueType(10) ** 46 + if n == 47: + return ValueType(10) ** 47 + if n == 48: + return ValueType(10) ** 48 + if n == 49: + return ValueType(10) ** 49 + if n == 50: + return ValueType(10) ** 50 + if n == 51: + return ValueType(10) ** 51 + if n == 52: + return ValueType(10) ** 52 + if n == 53: + return ValueType(10) ** 53 + if n == 54: + return ValueType(10) ** 54 + if n == 55: + return ValueType(10) ** 55 + if n == 56: + return ValueType(10) ** 56 + + return ValueType(10) ** n diff --git a/tests/test_sqrt.mojo b/tests/test_sqrt.mojo index d8e53a6..35b660d 100644 --- 
a/tests/test_sqrt.mojo +++ b/tests/test_sqrt.mojo @@ -823,32 +823,18 @@ fn test_sqrt_performance() raises: # Test case 7 var num7 = Decimal("0.999999999") - var result7 = sqrt(num7) - var squared7 = result7 * result7 - var diff7 = squared7 - num7 - diff7 = -diff7 if diff7.is_negative() else diff7 - var rel_diff7 = diff7 / num7 - var diff_float7 = Float64(String(rel_diff7)) + var result7 = String(sqrt(num7)) + var expected_result7 = String("0.99999999949999999987") testing.assert_true( - diff_float7 < 0.00001, - "Square root calculation for " - + String(num7) - + " should be accurate within 0.001%", + result7.startswith(expected_result7), "sqrt(0.999999999)" ) # Test case 8 var num8 = Decimal("1.000000001") - var result8 = sqrt(num8) - var squared8 = result8 * result8 - var diff8 = squared8 - num8 - diff8 = -diff8 if diff8.is_negative() else diff8 - var rel_diff8 = diff8 / num8 - var diff_float8 = Float64(String(rel_diff8)) + var result8 = String(sqrt(num8)) + var expected_result8 = String("1.000000000499999999875") testing.assert_true( - diff_float8 < 0.00001, - "Square root calculation for " - + String(num8) - + " should be accurate within 0.001%", + result8.startswith(expected_result8), "sqrt(1.000000001)" ) # Test case 9