diff --git a/README.md b/README.md index 7c39fb1..0bcf52a 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,12 @@ The core types are: This repository includes [TOMLMojo](https://github.com/forfudan/decimojo/tree/main/src/tomlmojo), a lightweight TOML parser in pure Mojo. It parses configuration files and test data, supporting basic types, arrays, and nested tables. While created for DeciMojo's testing framework, it offers general-purpose structured data parsing with a clean, simple API. -| type | information | internal representation | -| ------------ | ------------------------------------ | ------------------------ | -| `BigUInt` | arbitrary-precision unsigned integer | `List[UInt32]` | -| `BigInt` | arbitrary-precision integer | `BigUInt`, `Bool` | -| `Decimal` | 128-bit fixed-precision decimal | 4 `UInt32` words | -| `BigDecimal` | arbitrary-precision decimal | `BigUInt`, `Int`, `Bool` | +| type | alias | information | internal representation | +| ------------ | ------- | ------------------------------------ | ----------------------------------- | +| `BigUInt` | `BUInt` | arbitrary-precision unsigned integer | `List[UInt32]` | +| `BigInt` | `BInt` | arbitrary-precision integer | `BigUInt`, `Bool` | +| `Decimal` | `Dec` | 128-bit fixed-precision decimal | `UInt32`,`UInt32`,`UInt32`,`UInt32` | +| `BigDecimal` | `BDec` | arbitrary-precision decimal | `BigUInt`, `Int`, `Bool` | ## Installation diff --git a/benches/bigdecimal/bench.mojo b/benches/bigdecimal/bench.mojo index a6dffd9..5a308cc 100644 --- a/benches/bigdecimal/bench.mojo +++ b/benches/bigdecimal/bench.mojo @@ -2,6 +2,7 @@ from bench_bigdecimal_add import main as bench_add from bench_bigdecimal_subtract import main as bench_sub from bench_bigdecimal_multiply import main as bench_multiply from bench_bigdecimal_divide import main as bench_divide +from bench_bigdecimal_sqrt import main as bench_sqrt from bench_bigdecimal_scale_up_by_power_of_10 import main as bench_scale_up @@ -15,6 +16,7 @@ add: Add sub: Subtract mul: Multiply div: Divide (true divide) +sqrt: Square root all: Run all benchmarks q: Exit ========================================= @@ -31,11 +33,14 @@ scaleup: Scale up by power of 10 bench_multiply() elif command == "div": bench_divide() + elif command == "sqrt": + bench_sqrt() elif command == "all": bench_add() bench_sub() bench_multiply() bench_divide() + bench_sqrt() elif command == "q": return elif command == "scaleup": diff --git a/benches/bigdecimal/bench_bigdecimal_sqrt.mojo b/benches/bigdecimal/bench_bigdecimal_sqrt.mojo new file mode 100644 index 0000000..84b145c --- /dev/null +++ b/benches/bigdecimal/bench_bigdecimal_sqrt.mojo @@ -0,0 +1,663 @@ +""" +Comprehensive benchmarks for BigDecimal square root. +Compares performance against Python's decimal module with 50 diverse test cases. +""" + +from decimojo import BigDecimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_bigdecimal_sqrt_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark_sqrt( + name: String, + value: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo BigDecimal square root with Python Decimal square root. + + Args: + name: Name of the benchmark case. + value: String representation of the number to calculate the square root of. + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Value: " + value, log_file) + + # Set up Mojo and Python values + var mojo_value = BigDecimal(value) + var pydecimal = Python.import_module("decimal") + var py_value = pydecimal.Decimal(value) + + # Execute the operations once to verify correctness + try: + var mojo_result = mojo_value.sqrt() + var py_result = py_value.sqrt() + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = mojo_value.sqrt() + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_value.sqrt() + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo square root: " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python square root: " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + except e: + log_print("Error occurred during benchmark: " + String(e), log_file) + log_print("Skipping this benchmark case", log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo BigDecimal Square Root Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 1 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: 28", log_file) + + # Define benchmark cases + log_print( + "\nRunning decimal square root benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # === BASIC SQUARE ROOT TESTS === + + # Case 1: Simple integer square root + run_benchmark_sqrt( + "Simple integer square root", + "9", + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Simple decimal square root + run_benchmark_sqrt( + "Simple decimal square root", + "2.25", + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Square root with different scales + run_benchmark_sqrt( + "Square root with different scales", + "1.5625", + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Square root with very different scales + run_benchmark_sqrt( + "Square root with very different scales", + "0.0001", + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Square root of one + run_benchmark_sqrt( + "Square root of one", + "1", + iterations, + log_file, + speedup_factors, + ) + + # === SCALE AND PRECISION TESTS === + + # Case 6: Precision at decimal limit + run_benchmark_sqrt( + "Precision at decimal limit", + "2.0000000000000000000000000000", + iterations, + log_file, + speedup_factors, + ) + + # Case 7: Square root resulting in scale increase + run_benchmark_sqrt( + "Square root resulting in scale increase", + "0.01", + iterations, + log_file, + speedup_factors, + ) + + # Case 8: Square root with high precision + run_benchmark_sqrt( + "Square root with high precision", + "0.1111111111111111111111111111", + iterations, + log_file, + speedup_factors, + ) + + # Case 9: Square root resulting in exact integer + run_benchmark_sqrt( + "Square root resulting in exact integer", + "4", + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Square root with scientific notation + run_benchmark_sqrt( + "Square root with scientific notation", + "1.44e2", + iterations, + log_file, + speedup_factors, + ) + + # === LARGE NUMBER TESTS === + + # Case 11: Large integer square root + run_benchmark_sqrt( + "Large integer square root", + "9999999", + iterations, + log_file, + speedup_factors, + ) + + # Case 12: Large decimal square root + run_benchmark_sqrt( + "Large decimal square root", + "12345.6789", + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Very large square root + run_benchmark_sqrt( + "Very large square root", + "1" + "0" * 20, # 10^20 + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Extreme scales (large positive exponents) + run_benchmark_sqrt( + "Extreme scales (large positive exponents)", + "1.44e10", + iterations, + log_file, + speedup_factors, + ) + + # === SMALL NUMBER TESTS === + + # Case 15: Very small positive values + run_benchmark_sqrt( + "Very small positive values", + "0." + "0" * 15 + "1", + iterations, + log_file, + speedup_factors, + ) + + # Case 16: Extreme scales (large negative exponents) + run_benchmark_sqrt( + "Extreme scales (large negative exponents)", + "1.44e-10", + iterations, + log_file, + speedup_factors, + ) + + # === SPECIAL VALUE TESTS === + + # Case 17: Square root of exact mathematical constants + run_benchmark_sqrt( + "Square root of exact mathematical constants (PI)", + "3.14159265358979323846264338328", + iterations, + log_file, + speedup_factors, + ) + + # Case 18: Square root of 2 + run_benchmark_sqrt( + "Square root of 2", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 19: Square root of 3 + run_benchmark_sqrt( + "Square root of 3", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 20: Square root of 5 + run_benchmark_sqrt( + "Square root of 5", + "5", + iterations, + log_file, + speedup_factors, + ) + + # Case 21: Square root of 7 + run_benchmark_sqrt( + "Square root of 7", + "7", + iterations, + log_file, + speedup_factors, + ) + + # Case 22: Square root of 11 + run_benchmark_sqrt( + "Square root of 11", + "11", + iterations, + log_file, + speedup_factors, + ) + + # Case 23: Square root of 13 + run_benchmark_sqrt( + "Square root of 13", + "13", + iterations, + log_file, + speedup_factors, + ) + + # Case 24: Square root of 17 + run_benchmark_sqrt( + "Square root of 17", + "17", + iterations, + log_file, + speedup_factors, + ) + + # Case 25: Square root of 19 + run_benchmark_sqrt( + "Square root of 19", + "19", + iterations, + log_file, + speedup_factors, + ) + + # Case 26: Square root of 23 + run_benchmark_sqrt( + "Square root of 23", + "23", + iterations, + log_file, + speedup_factors, + ) + + # Case 27: Square root of 29 + run_benchmark_sqrt( + "Square root of 29", + "29", + iterations, + log_file, + speedup_factors, + ) + + # Case 28: Square root of 31 + run_benchmark_sqrt( + "Square root of 31", + "31", + iterations, + log_file, + speedup_factors, + ) + + # Case 29: Square root of 37 + run_benchmark_sqrt( + "Square root of 37", + "37", + iterations, + log_file, + speedup_factors, + ) + + # Case 30: Square root of 41 + run_benchmark_sqrt( + "Square root of 41", + "41", + iterations, + log_file, + speedup_factors, + ) + + # Case 31: Square root of 43 + run_benchmark_sqrt( + "Square root of 43", + "43", + iterations, + log_file, + speedup_factors, + ) + + # Case 32: Square root of 47 + run_benchmark_sqrt( + "Square root of 47", + "47", + iterations, + log_file, + speedup_factors, + ) + + # Case 33: Square root of 53 + run_benchmark_sqrt( + "Square root of 53", + "53", + iterations, + log_file, + speedup_factors, + ) + + # Case 34: Square root of 59 + run_benchmark_sqrt( + "Square root of 59", + "59", + iterations, + log_file, + speedup_factors, + ) + + # Case 35: Square root of 61 + run_benchmark_sqrt( + "Square root of 61", + "61", + iterations, + log_file, + speedup_factors, + ) + + # Case 36: Square root of 67 + run_benchmark_sqrt( + "Square root of 67", + "67", + iterations, + log_file, + speedup_factors, + ) + + # Case 37: Square root of 71 + run_benchmark_sqrt( + "Square root of 71", + "71", + iterations, + log_file, + speedup_factors, + ) + + # Case 38: Square root of 73 + run_benchmark_sqrt( + "Square root of 73", + "73", + iterations, + log_file, + speedup_factors, + ) + + # Case 39: Square root of 79 + run_benchmark_sqrt( + "Square root of 79", + "79", + iterations, + log_file, + speedup_factors, + ) + + # Case 40: Square root of 83 + run_benchmark_sqrt( + "Square root of 83", + "83", + iterations, + log_file, + speedup_factors, + ) + + # Case 41: Square root of 89 + run_benchmark_sqrt( + "Square root of 89", + "89", + iterations, + log_file, + speedup_factors, + ) + + # Case 42: Square root of 97 + run_benchmark_sqrt( + "Square root of 97", + "97", + iterations, + log_file, + speedup_factors, + ) + + # Case 43: Square root of 101 + run_benchmark_sqrt( + "Square root of 101", + "101", + iterations, + log_file, + speedup_factors, + ) + + # Case 44: Square root of 103 + run_benchmark_sqrt( + "Square root of 103", + "103", + iterations, + log_file, + speedup_factors, + ) + + # Case 45: Square root of 107 + run_benchmark_sqrt( + "Square root of 107", + "107", + iterations, + log_file, + speedup_factors, + ) + + # Case 46: Square root of 109 + run_benchmark_sqrt( + "Square root of 109", + "109", + iterations, + log_file, + speedup_factors, + ) + + # Case 47: Square root of 113 + run_benchmark_sqrt( + "Square root of 113", + "113", + iterations, + log_file, + speedup_factors, + ) + + # Case 48: Square root of 127 + run_benchmark_sqrt( + "Square root of 127", + "127", + iterations, + log_file, + speedup_factors, + ) + + # Case 49: Square root of 131 + run_benchmark_sqrt( + "Square root of 131", + "131", + iterations, + log_file, + speedup_factors, + ) + + # Case 50: Square root of 137 + run_benchmark_sqrt( + "Square root of 137", + "137", + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor (ignoring any cases that might have failed) + if len(speedup_factors) > 0: + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print( + "\n=== BigDecimal Square Root Benchmark Summary ===", log_file + ) + log_print( + "Benchmarked: " + + String(len(speedup_factors)) + + " different square root cases", + log_file, + ) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print( + "Average speedup: " + String(average_speedup) + "×", log_file + ) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format( + i + 1, round(speedup_factors[i], 2) + ), + log_file, + ) + else: + log_print("\nNo valid benchmark cases were completed", log_file) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/src/decimojo/bigdecimal/bigdecimal.mojo b/src/decimojo/bigdecimal/bigdecimal.mojo index f2694c8..f26db54 100644 --- a/src/decimojo/bigdecimal/bigdecimal.mojo +++ b/src/decimojo/bigdecimal/bigdecimal.mojo @@ -541,10 +541,16 @@ struct BigDecimal: """Returns the maximum of two BigDecimal numbers.""" return decimojo.bigdecimal.comparison.max(self, other) + @always_inline fn min(self, other: Self) raises -> Self: """Returns the minimum of two BigDecimal numbers.""" return decimojo.bigdecimal.comparison.min(self, other) + @always_inline + fn sqrt(self, precision: Int = 28) raises -> Self: + """Returns the square root of the BigDecimal number.""" + return decimojo.bigdecimal.exponential.sqrt(self, precision) + @always_inline fn true_divide(self, other: Self, precision: Int) raises -> Self: """Returns the result of true division of two BigDecimal numbers. @@ -669,6 +675,11 @@ struct BigDecimal: ) print("----------------------------------------------") + @always_inline + fn is_negative(self) -> Bool: + """Returns True if this number represents a negative value.""" + return self.sign + @always_inline fn is_zero(self) -> Bool: """Returns True if this number represents zero.""" diff --git a/src/decimojo/bigdecimal/exponential.mojo b/src/decimojo/bigdecimal/exponential.mojo new file mode 100644 index 0000000..acb11eb --- /dev/null +++ b/src/decimojo/bigdecimal/exponential.mojo @@ -0,0 +1,159 @@ +# ===----------------------------------------------------------------------=== # +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Implements exponential functions for the BigDecimal type.""" + +from decimojo.bigdecimal.bigdecimal import BigDecimal +from decimojo.rounding_mode import RoundingMode +import decimojo.utility + + +fn sqrt(x: BigDecimal, precision: Int = 28) raises -> BigDecimal: + """Calculate the square root of a BigDecimal number. + + Args: + x: The number to calculate the square root of. + precision: The desired precision (number of significant digits) of the result. + + Returns: + The square root of x with the specified precision. + + Raises: + Error: If x is negative. + """ + alias BUFFER_DIGITS = 9 + + # Handle special cases + if x.sign: + raise Error( + "Error in `sqrt`: Cannot compute square root of negative number" + ) + + if x.coefficient.is_zero(): + return BigDecimal(BigUInt.ZERO, (x.scale + 1) // 2, False) + + # Initial guess + # A decimal has coefficient and scale + # Example 1: + # 123456789012345678901234567890.12345 (sqrt ~= 351364182882014.4253111222382) + # coef = 12345678_901234567_890123456_789012345, scale = 5 + # first three words = 12345678_901234567_890123456 + # number of integral digits = 30 + # Because it is even, no need to scale up by 10 + # not scale up by 10 => 12345678901234567890123456 + # sqrt(12345678901234567890123456) = 3513641828820 + # number of integral digits of the sqrt = (30 + 1) // 2 = 15 + # coef = 3513641828820, 13 digits, so scale = 13 - 15 + # + # Example 2: + # 12345678901.234567890123456789012345 (sqrt ~= 111111.1106111111099361111058) + # coef = 12345678_901234567_890123456_789012345, scale = 24 + # first three words = 12345678_901234567_890123456 + # remaining number of words = 11 + # Because it is odd, need to scale up by 10 + # scale up by 10 => 123456789012345678901234560 + # sqrt(123456789012345678901234560) = 11111111061111 + # number of integral digits of the sqrt = (11 + 1) // 2 = 6 + # coef = 11111111061111, 14 digits, so scale = 14 - 6 => (111111.11061111) + + var guess: BigDecimal + var ndigits_coef = x.coefficient.number_of_digits() + var ndigits_int_part = x.coefficient.number_of_digits() - x.scale + var ndigits_int_part_sqrt = (ndigits_int_part + 1) // 2 + var odd_ndigits_frac_part = x.scale % 2 == 1 + + var value: UInt128 + if ndigits_coef <= 9: + value = UInt128(x.coefficient.words[0]) * UInt128( + 1_000_000_000_000_000_000 + ) + elif ndigits_coef <= 18: + value = ( + UInt128(x.coefficient.words[-1]) + * UInt128(1_000_000_000_000_000_000) + ) + (UInt128(x.coefficient.words[-2]) * UInt128(1_000_000_000)) + else: # ndigits_coef > 18 + value = ( + ( + UInt128(x.coefficient.words[-1]) + * UInt128(1_000_000_000_000_000_000) + ) + + UInt128(x.coefficient.words[-2]) * UInt128(1_000_000_000) + + UInt128(x.coefficient.words[-3]) + ) + if odd_ndigits_frac_part: + value = value * UInt128(10) + var sqrt_value = decimojo.utility.sqrt(value) + var sqrt_value_biguint = BigUInt.from_scalar(sqrt_value) + guess = BigDecimal( + sqrt_value_biguint, + sqrt_value_biguint.number_of_digits() - ndigits_int_part_sqrt, + False, + ) + + # For Newton's method, we need extra precision during calculations + # to ensure the final result has the desired precision + var working_precision = precision + BUFFER_DIGITS + + # Newton's method iterations + # x_{n+1} = (x_n + N/x_n) / 2 + var prev_guess = BigDecimal(BigUInt.ZERO, 0, False) + var iteration_count = 0 + + while guess != prev_guess and iteration_count < 100: + prev_guess = guess + var quotient = x.true_divide(guess, working_precision) + var sum = guess + quotient + guess = sum.true_divide(BigDecimal(BigUInt(2), 0, 0), working_precision) + iteration_count += 1 + + # Round to the desired precision + var ndigits_to_remove = guess.coefficient.number_of_digits() - precision + if ndigits_to_remove > 0: + var coefficient = guess.coefficient + coefficient = coefficient.remove_trailing_digits_with_rounding( + ndigits_to_remove, + rounding_mode=RoundingMode.ROUND_HALF_UP, + remove_extra_digit_due_to_rounding=True, + ) + guess.coefficient = coefficient^ + guess.scale -= ndigits_to_remove + + # Remove trailing zeros for exact results + # TODO: This can be done even earlier in the process + if guess.coefficient.ith_digit(0) == 0: + var guess_coefficient_without_trailing_zeros = guess.coefficient.remove_trailing_digits_with_rounding( + guess.coefficient.number_of_trailing_zeros() + ) + var x_coefficient_without_trailing_zeros = x.coefficient.remove_trailing_digits_with_rounding( + x.coefficient.number_of_trailing_zeros() + ) + if ( + guess_coefficient_without_trailing_zeros + * guess_coefficient_without_trailing_zeros + ) == x_coefficient_without_trailing_zeros: + var expected_ndigits_of_result = ( + x.coefficient.number_of_digits() + 1 + ) // 2 + guess.scale = (x.scale + 1) // 2 + guess.coefficient = ( + guess.coefficient.remove_trailing_digits_with_rounding( + guess.coefficient.number_of_digits() + - expected_ndigits_of_result + ) + ) + + return guess^ diff --git a/src/decimojo/decimal/exponential.mojo b/src/decimojo/decimal/exponential.mojo index 90e4fa9..967b7d6 100644 --- a/src/decimojo/decimal/exponential.mojo +++ b/src/decimojo/decimal/exponential.mojo @@ -1,7 +1,4 @@ # ===----------------------------------------------------------------------=== # -# DeciMojo: A fixed-point decimal arithmetic library in Mojo -# https://github.com/forfudan/decimojo -# # Copyright 2025 Yuhao Zhu # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/decimojo/utility.mojo b/src/decimojo/utility.mojo index 8b5c7f8..07020c7 100644 --- a/src/decimojo/utility.mojo +++ b/src/decimojo/utility.mojo @@ -20,6 +20,7 @@ # ===----------------------------------------------------------------------=== # from memory import UnsafePointer +import sys import time from decimojo.decimal.decimal import Decimal @@ -193,6 +194,32 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: return truncated_value +fn sqrt(x: UInt128) -> UInt128: + """ + Returns the square root of a UInt128 value. + + Args: + x: The UInt128 value to calculate the square root for. + + Returns: + The square root of the UInt128 value. + """ + + if x < 0: + return 0 + + var r: UInt128 = 0 + + for p in range(sys.bitwidthof[UInt128]() // 2 - 1, -1, -1): + var new_bit = UInt128(1) << p + var would_be = r | new_bit + var squared = would_be * would_be + if squared <= x: + r = would_be + + return r + + # TODO: Evaluate whether this can replace truncate_to_max in some cases. # TODO: Add rounding modes to this function. fn round_to_keep_first_n_digits[ diff --git a/tests/bigdecimal/test_bigdecimal_exponential.mojo b/tests/bigdecimal/test_bigdecimal_exponential.mojo new file mode 100644 index 0000000..d392b2c --- /dev/null +++ b/tests/bigdecimal/test_bigdecimal_exponential.mojo @@ -0,0 +1,120 @@ +""" +Test BigDecimal exponential operations including square root. +""" + +from python import Python +import testing + +from decimojo import BigDecimal, RoundingMode +from decimojo.tests import TestCase +from tomlmojo import parse_file + +alias exponential_file_path = "tests/bigdecimal/test_data/bigdecimal_exponential.toml" + + +fn load_test_cases( + file_path: String, table_name: String +) raises -> List[TestCase]: + """Load test cases from a TOML file for a specific table.""" + var toml = parse_file(file_path) + var test_cases = List[TestCase]() + + # Get array of test cases + var cases_array = toml.get_array_of_tables(table_name) + + for i in range(len(cases_array)): + var case_table = cases_array[i] + test_cases.append( + TestCase( + case_table["input"].as_string(), + "", # No second input for sqrt + case_table["expected"].as_string(), + case_table["description"].as_string(), + ) + ) + + return test_cases + + +fn test_sqrt() raises: + """Test BigDecimal square root with various test cases.""" + print("------------------------------------------------------") + print("Testing BigDecimal square root...") + + var pydecimal = Python.import_module("decimal") + + # Load test cases from TOML file + var test_cases = load_test_cases(exponential_file_path, "sqrt_tests") + print("Loaded", len(test_cases), "test cases for square root") + + # Track test results + var passed = 0 + var failed = 0 + + # Run all test cases in a loop + for i in range(len(test_cases)): + var test_case = test_cases[i] + var input_value = BigDecimal(test_case.a) + var expected = BigDecimal(test_case.expected) + + # Calculate square root + var result = input_value.sqrt() + + try: + # Using String comparison for easier debugging + testing.assert_equal( + String(result), String(expected), test_case.description + ) + passed += 1 + except e: + print( + "=" * 50, + "\n", + i + 1, + "failed:", + test_case.description, + "\n Input:", + test_case.a, + "\n Expected:", + test_case.expected, + "\n Got:", + String(result), + "\n Python decimal result (for reference):", + String(pydecimal.Decimal(test_case.a).sqrt()), + ) + failed += 1 + + print("BigDecimal sqrt tests:", passed, "passed,", failed, "failed") + testing.assert_equal(failed, 0, "All square root tests should pass") + + +fn test_negative_sqrt() raises: + """Test that square root of negative number raises an error.""" + print("------------------------------------------------------") + print("Testing BigDecimal square root with negative input...") + + var negative_number = BigDecimal("-1") + + var exception_caught = False + try: + _ = negative_number.sqrt() + exception_caught = False + except: + exception_caught = True + + testing.assert_true( + exception_caught, "Square root of negative number should raise an error" + ) + print("✓ Square root of negative number correctly raises an error") + + +fn main() raises: + print("Running BigDecimal exponential tests") + + # Run sqrt tests + test_sqrt() + + # Test sqrt of negative number + test_negative_sqrt() + + print("All BigDecimal exponential tests passed!") diff --git a/tests/bigdecimal/test_data/bigdecimal_exponential.toml b/tests/bigdecimal/test_data/bigdecimal_exponential.toml new file mode 100644 index 0000000..6a9ac5c --- /dev/null +++ b/tests/bigdecimal/test_data/bigdecimal_exponential.toml @@ -0,0 +1,152 @@ +# === BASIC SQUARE ROOT TESTS === +[[sqrt_tests]] +input = "0" +expected = "0" +description = "Square root of zero" + +[[sqrt_tests]] +input = "1" +expected = "1" +description = "Square root of one" + +[[sqrt_tests]] +input = "4" +expected = "2" +description = "Square root of perfect square (small)" + +[[sqrt_tests]] +input = "9" +expected = "3" +description = "Square root of perfect square (single digit)" + +[[sqrt_tests]] +input = "16" +expected = "4" +description = "Square root of perfect square (16)" + +[[sqrt_tests]] +input = "25" +expected = "5" +description = "Square root of perfect square (25)" + +[[sqrt_tests]] +input = "100" +expected = "10" +description = "Square root of perfect square (100)" + +# === NON-PERFECT SQUARES === +[[sqrt_tests]] +input = "2" +expected = "1.414213562373095048801688724" +description = "Square root of 2 (irrational)" + +[[sqrt_tests]] +input = "3" +expected = "1.732050807568877293527446342" +description = "Square root of 3 (irrational)" + +[[sqrt_tests]] +input = "5" +expected = "2.236067977499789696409173669" +description = "Square root of 5 (irrational)" + +[[sqrt_tests]] +input = "10" +expected = "3.162277660168379331998893544" +description = "Square root of 10 (irrational)" + +# === DECIMAL INPUTS === +[[sqrt_tests]] +input = "0.25" +expected = "0.5" +description = "Square root of 0.25" + +[[sqrt_tests]] +input = "0.01" +expected = "0.1" +description = "Square root of 0.01" + +[[sqrt_tests]] +input = "0.0625" +expected = "0.25" +description = "Square root of 0.0625" + +[[sqrt_tests]] +input = "2.25" +expected = "1.5" +description = "Square root of 2.25" + +[[sqrt_tests]] +input = "12.25" +expected = "3.5" +description = "Square root of 12.25" + +# === LARGE NUMBERS === +[[sqrt_tests]] +input = "1000000" +expected = "1000" +description = "Square root of large perfect square" + +[[sqrt_tests]] +input = "9999999999" +expected = "99999.99999499999999987500000" +description = "Square root of large near-perfect square" + +[[sqrt_tests]] +input = "1000000000000000000000000000" +expected = "31622776601683.79331998893544" +description = "Square root of very large imperfect square" + +# === SMALL NUMBERS === +[[sqrt_tests]] +input = "0.0001" +expected = "0.01" +description = "Square root of small number" + +[[sqrt_tests]] +input = "0.000000000001" +expected = "0.000001" +description = "Square root of very small number" + +[[sqrt_tests]] +input = "0.0000000000000000000000000001" +expected = "1E-14" +description = "Square root of extremely small number" + +# === SCIENTIFIC NOTATION === +[[sqrt_tests]] +input = "1e10" +expected = "1E+5" +description = "Square root with scientific notation (positive exponent)" + +[[sqrt_tests]] +input = "1e-10" +expected = "0.00001" +description = "Square root with scientific notation (negative exponent)" + +# === PRECISION TESTS === +[[sqrt_tests]] +input = "2" +expected = "1.414213562373095048801688724" +description = "High precision square root of 2" + +[[sqrt_tests]] +input = "0.9999999999999999" +expected = "0.9999999999999999500000000000" +description = "Square root slightly less than 1" + +[[sqrt_tests]] +input = "1.0000000000000001" +expected = "1.000000000000000050000000000" +description = "Square root slightly more than 1" + +# === APPLICATION SCENARIOS === +[[sqrt_tests]] +input = "3.14159265358979323846" +expected = "1.772453850905516027297421799" +description = "Square root of π" + +[[sqrt_tests]] +input = "2.71828182845904523536" +expected = "1.648721270700128146848563608" +description = "Square root of e"