diff --git a/benches/bench_root.mojo b/benches/bench_root.mojo new file mode 100644 index 0000000..6953c9e --- /dev/null +++ b/benches/bench_root.mojo @@ -0,0 +1,401 @@ +""" +Comprehensive benchmarks for Decimal nth root function (root). +Compares performance against Python's decimal module with diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_root_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark( + name: String, + value: String, + nth_root: Int, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal root with Python Decimal power(1/n). + + Args: + name: Name of the benchmark case. + value: String representation of the number to find the nth root of. + nth_root: The root value (2 for square root, 3 for cube root, etc.). + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Value: " + value, log_file) + log_print("Root: " + String(nth_root), log_file) + + # Set up Mojo and Python values + var mojo_decimal = Decimal(value) + var pydecimal = Python.import_module("decimal") + var py_decimal = pydecimal.Decimal(value) + var py_root = pydecimal.Decimal(String(nth_root)) + var py_frac = pydecimal.Decimal(1) / py_root + + # Special case: Python can't directly compute odd root of negative number + var is_negative_odd_root = value.startswith("-") and nth_root % 2 == 1 + var py_result: PythonObject + + # Execute the operations once to verify correctness + var mojo_result = dm.exponential.root(mojo_decimal, nth_root) + + # Handle Python calculation, accounting for negative odd root limitation + if is_negative_odd_root: + # For negative numbers with odd roots in Python, we need to: + # 1. Take absolute value + # 2. Compute the root + # 3. Negate the result + var abs_py_decimal = py_decimal.copy_abs() + py_result = -(abs_py_decimal**py_frac) + log_print( + ( + "Note: Python doesn't directly support odd roots of negative" + " numbers." + ), + log_file, + ) + log_print( + " Using abs() and then negating the result for comparison.", + log_file, + ) + else: + try: + py_result = py_decimal**py_frac + except: + log_print( + "Python cannot compute this root. Skipping Python benchmark.", + log_file, + ) + py_result = Python.evaluate( + "None" + ) # Correct way to get Python's None + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + if not ( + py_result is Python.evaluate("None") + ): # Correct way to check for None + log_print("Python result: " + String(py_result), log_file) + else: + log_print("Python result: ERROR - cannot compute", log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = dm.exponential.root(mojo_decimal, nth_root) + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation (if possible) + var python_time: Float64 = 0 + if not is_negative_odd_root and not ( + py_result is Python.evaluate("None") + ): # Correct way to check for None + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_decimal**py_frac + python_time = (perf_counter_ns() - t0) / iterations + elif is_negative_odd_root: + # For negative numbers with odd roots, benchmark our workaround + var abs_py_decimal = py_decimal.copy_abs() + t0 = perf_counter_ns() + for _ in range(iterations): + _ = -(abs_py_decimal**py_frac) + python_time = (perf_counter_ns() - t0) / iterations + else: + log_print("Python benchmark skipped", log_file) + python_time = 0 + + # Calculate speedup factor (if Python benchmark ran) + if python_time > 0: + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo root(): " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python root(): " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + else: + log_print( + "Mojo root(): " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print("Python root(): N/A", log_file) + log_print("Speedup factor: N/A", log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Root Function Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 100 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning root function benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Square root of perfect square + run_benchmark( + "Square root of perfect square", + "9", + 2, + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Square root of non-perfect square + run_benchmark( + "Square root of non-perfect square", + "2", + 2, + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Cube root of perfect cube + run_benchmark( + "Cube root of perfect cube", + "8", + 3, + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Cube root of non-perfect cube + run_benchmark( + "Cube root of non-perfect cube", + "10", + 3, + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Fourth root of perfect power + run_benchmark( + "Fourth root of perfect power", + "16", + 4, + iterations, + log_file, + speedup_factors, + ) + + # Case 6: Fifth root of perfect power + run_benchmark( + "Fifth root of perfect power", + "32", + 5, + iterations, + log_file, + speedup_factors, + ) + + # Case 7: Root of decimal < 1 + run_benchmark( + "Root of decimal < 1", + "0.25", + 2, + iterations, + log_file, + speedup_factors, + ) + + # Case 8: Root of decimal < 1 + run_benchmark( + "Root of small decimal", + "0.0625", + 4, + iterations, + log_file, + speedup_factors, + ) + + # Case 9: High precision decimal + run_benchmark( + "High precision decimal", + "2.7182818284590452353602874", + 2, + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Large integer + run_benchmark( + "Large integer", + "1000000", + 2, + iterations, + log_file, + speedup_factors, + ) + + # Case 11: Large root + run_benchmark( + "Large root", + "10", + 100, + iterations, + log_file, + speedup_factors, + ) + + # Case 12: Odd root of negative number + run_benchmark( + "Odd root of negative number", + "-27", + 3, + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Root of 1 (any root) + run_benchmark( + "Root of 1 (any root)", + "1", + 7, + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Root of 0 + run_benchmark( + "Root of 0", + "0", + 3, + iterations, + log_file, + speedup_factors, + ) + + # Case 15: Custom decimal + run_benchmark( + "Custom decimal", + "123.456", + 2, + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== Root Function Benchmark Summary ===", log_file) + log_print("Benchmarked: 15 different root() cases", log_file) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print("Average speedup: " + String(average_speedup) + "×", log_file) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)), + log_file, + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/mojoproject.toml b/mojoproject.toml index 272c874..401fd32 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -36,6 +36,7 @@ test_arith = "magic run package && magic run mojo test tests/test_arithmetics.mo test_multiply = "magic run package && magic run mojo test tests/test_multiply.mojo && magic run delete_package" test_divide = "magic run package && magic run mojo test tests/test_divide.mojo && magic run delete_package" test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && magic run delete_package" +test_root = "magic run package && magic run mojo test tests/test_root.mojo && magic run delete_package" test_round = "magic run package && magic run mojo test tests/test_round.mojo && magic run delete_package" test_creation = "magic run package && magic run mojo test tests/test_creation.mojo && magic run delete_package" test_from_float = "magic run package && magic run mojo test tests/test_from_float.mojo && magic run delete_package" @@ -53,6 +54,7 @@ b = "clear && magic run bench" bench_multiply = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .. && magic run delete_package" bench_divide = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .. && magic run delete_package" bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo && cd .. && magic run delete_package" +bench_root = "magic run package && cd benches && magic run mojo bench_root.mojo && cd .. && magic run delete_package" bench_round = "magic run package && cd benches && magic run mojo bench_round.mojo && cd .. && magic run delete_package" bench_from_float = "magic run package && cd benches && magic run mojo bench_from_float.mojo && cd .. && magic run delete_package" bench_from_string = "magic run package && cd benches && magic run mojo bench_from_string.mojo && cd .. && magic run delete_package" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 178149e..a1eedca 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -49,7 +49,7 @@ from .comparison import ( not_equal, ) -from .exponential import power, sqrt, exp, ln +from .exponential import power, root, sqrt, exp, ln from .rounding import round diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index 2055bd6..d4b06e3 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -1408,18 +1408,23 @@ struct Decimal( except e: raise Error("Error in `Decimal.round()`: ", e) - fn sqrt(self) raises -> Self: + fn root(self, n: Int) raises -> Self: + """Calculates the n-th root of this Decimal. + See `root()` for more information. """ - Calculates the square root of this Decimal. - - Returns: - The square root of this Decimal. + try: + return decimojo.exponential.root(self, n) + except e: + raise Error("Error in `Decimal.root()`: ", e) - Raises: - Error: If the operation would result in overflow. + fn sqrt(self) raises -> Self: + """Calculates the square root of this Decimal. + See `sqrt()` for more information. """ - - return decimojo.exponential.sqrt(self) + try: + return decimojo.exponential.sqrt(self) + except e: + raise Error("Error in `Decimal.sqrt()`: ", e) # ===------------------------------------------------------------------=== # # Other methods diff --git a/src/decimojo/exponential.mojo b/src/decimojo/exponential.mojo index e5fdcef..f961212 100644 --- a/src/decimojo/exponential.mojo +++ b/src/decimojo/exponential.mojo @@ -31,6 +31,7 @@ import math as builtin_math import testing +import time import decimojo.constants import decimojo.special @@ -155,6 +156,174 @@ fn power(base: Decimal, exponent: Int) raises -> Decimal: return result +fn root(x: Decimal, n: Int) raises -> Decimal: + """Calculates the n-th root of a Decimal value using Newton-Raphson method. + + Args: + x: The Decimal value to compute the n-th root of. + n: The root to compute (must be positive). + + Returns: + A new Decimal containing the n-th root of x. + + Raises: + Error: If x is negative and n is even. + Error: If n is zero or negative. + """ + # var t0 = time.perf_counter_ns() + + # Special cases for n + if n <= 0: + raise Error("Error in `root()`: Cannot compute non-positive root") + if n == 1: + return x + if n == 2: + return sqrt(x) + + # Special cases for x + if x.is_zero(): + return Decimal.ZERO() + if x.is_one(): + return Decimal.ONE() + if x.is_negative(): + if n % 2 == 0: + raise Error( + "Error in `root()`: Cannot compute even root of a negative" + " number" + ) + # For odd roots of negative numbers, compute |x|^(1/n) and negate + return -root(-x, n) + + # Special optimization for very large n + if n > 50: + # For large n, the Newton-Raphson method may converge slowly + # Use logarithm approach directly with higher precision + try: + # Direct calculation: x^n = e^(ln(x)/n) + return exp(ln(x) / Decimal(n)) + except e: + raise Error("Error in `root()`: ", e) + + # Initial guess + # use floating point approach to quickly find a good guess + var x_coef: UInt128 = x.coefficient() + var x_scale = x.scale() + var guess: Decimal + + # For numbers with zero scale (true integers) + if x_scale == 0: + if n <= 8: # 3<=n<=8 + var float_root = pow(Float64(x_coef), 1 / Float64(n)) * Float64( + 10 + ) ** 8 + guess = Decimal.from_uint128( + UInt128(round(float_root)), scale=8, sign=False + ) + elif n <= 16: + var float_root = pow(Float64(x_coef), 1 / Float64(n)) * Float64( + 10 + ) ** 16 + guess = Decimal.from_uint128( + UInt128(round(float_root)), scale=16, sign=False + ) + else: + var float_root = pow(Float64(x_coef), 1 / Float64(n)) * Float64( + 10 + ) ** 26 + guess = Decimal.from_uint128( + UInt128(round(float_root)), scale=26, sign=False + ) + + # Otherwise, use the following formulae: + # let divmod(scale, n) = (x, y) + # so scale = x * n + y = (x + 1) * n + (y - n) + # a^(1/n) / (10^scale)^(1/n) + # = a^(1/n) / (10^(scale/n)) + # = a^(1/n) / (10^((x + 1) * n + y - n) / n)) + # = a^(1/n) / (10^(x+1 + (y-n)/n)) + # = a^(1/n) / 10^(x+1) / 10^((y-n)/n) + # = a^(1/n) / 10^((y/n-1) / 10^(x+1) + else: + var dividend = x_scale // n + var remainder = x_scale % n + var float_root = ( + Float64(x_coef) ** (Float64(1) / Float64(n)) + / Float64(10) ** (Float64(remainder) / Float64(n) - 1) + ) + guess = Decimal.from_uint128( + UInt128(float_root), scale=dividend + 1, sign=False + ) + + # var t_initial_guess = time.perf_counter_ns() + + # Newton-Raphson method for n-th root + # Formula: x_{k+1} = ((n-1)*x_k + a/x_k^(n-1))/n + var prev_guess = Decimal.ZERO() + var n_decimal = Decimal(n) + var n_minus_1 = n - 1 + var n_minus_1_decimal = Decimal(n_minus_1) + var iteration_count = 0 + + # Newton-Raphson iteration + while guess != prev_guess and iteration_count < 100: + prev_guess = guess + var pow_n_minus_1 = power(guess, n_minus_1) + var sum_result = n_minus_1_decimal * guess + x / pow_n_minus_1 + guess = sum_result / n_decimal + iteration_count += 1 + + # var t_newton_raphson = time.perf_counter_ns() + + # If exact root found, remove trailing zeros after the decimal point + # For example, root(27, 3) = 9, not 3.0000000000000 + # Exact root means that the n-th power of coefficient of guess after + # removing trailing zeros is equal to the coefficient of xs + var guess_coef = guess.coefficient() + + # No need to do this if the last digit of the coefficient of guess is not zero + if guess_coef % 10 == 0: + var num_digits_x_ceof = decimojo.utility.number_of_digits(x_coef) + var num_digits_x_root_coef = (num_digits_x_ceof // n) + 1 + var num_digits_guess_coef = decimojo.utility.number_of_digits( + guess_coef + ) + var num_digits_to_decrease = num_digits_guess_coef - num_digits_x_root_coef + + # testing.assert_true( + # num_digits_to_decrease >= 0, + # "root of x has fewer digits than expected", + # ) + for _ in range(num_digits_to_decrease): + if guess_coef % 10 == 0: + guess_coef //= 10 + else: + break + else: + var guess_coef_powered = guess_coef**n + if guess_coef_powered == x_coef: + return Decimal.from_uint128( + guess_coef, + scale=guess.scale() - num_digits_to_decrease, + sign=False, + ) + if guess_coef_powered == x_coef * decimojo.utility.power_of_10[ + DType.uint128 + ](n): + return Decimal.from_uint128( + guess_coef // 10, + scale=guess.scale() - num_digits_to_decrease - 1, + sign=False, + ) + + # print("DEBUG: iteration_count", iteration_count) + # var t_remove_zeros = time.perf_counter_ns() + # print("TIME: initial guess", t_initial_guess - t0) + # print("TIME: Newton-Raphson", t_newton_raphson - t_initial_guess) + # print("TIME: remove zeros", t_remove_zeros - t_newton_raphson) + + return guess + + fn sqrt(x: Decimal) raises -> Decimal: """Computes the square root of a Decimal value using Newton-Raphson method. @@ -176,19 +345,16 @@ fn sqrt(x: Decimal) raises -> Decimal: if x.is_zero(): return Decimal.ZERO() + # Initial guess + # use floating point approach to quickly find a good guess var x_coef: UInt128 = x.coefficient() var x_scale = x.scale() - - # Initial guess - a good guess helps converge faster - # use floating point approach to quickly find a good guess - var guess: Decimal # For numbers with zero scale (true integers) if x_scale == 0: var float_sqrt = builtin_math.sqrt(Float64(x_coef)) - guess = Decimal.from_uint128(UInt128(float_sqrt)) - # print("DEBUG: scale = 0") + guess = Decimal.from_uint128(UInt128(round(float_sqrt))) # For numbers with even scale elif x_scale % 2 == 0: @@ -207,7 +373,7 @@ fn sqrt(x: Decimal) raises -> Decimal: # print("DEBUG: scale is odd") # print("DEBUG: initial guess", guess) - testing.assert_false(guess.is_zero(), "Initial guess should not be zero") + # testing.assert_false(guess.is_zero(), "Initial guess should not be zero") # Newton-Raphson iterations # x_n+1 = (x_n + S/x_n) / 2 @@ -231,11 +397,11 @@ fn sqrt(x: Decimal) raises -> Decimal: # print("DEBUG: iteration_count", iteration_count) - # If exact square root found remove trailing zeros after the decimal point + # If exact square root found, remove trailing zeros after the decimal point # For example, sqrt(81) = 9, not 9.000000 # For example, sqrt(100.0000) = 10.00 not 10.000000 - # Exact square means that the coefficient of guess after removing trailing zeros - # is equal to the coefficient of x + # Exact square means that the squared coefficient of guess after removing + # trailing zeros is equal to the coefficient of x var guess_coef = guess.coefficient() @@ -248,10 +414,10 @@ fn sqrt(x: Decimal) raises -> Decimal: ) var num_digits_to_decrease = num_digits_guess_coef - num_digits_x_sqrt_coef - testing.assert_true( - num_digits_to_decrease >= 0, - "sqrt of x has fewer digits than expected", - ) + # testing.assert_true( + # num_digits_to_decrease >= 0, + # "sqrt of x has fewer digits than expected", + # ) for _ in range(num_digits_to_decrease): if guess_coef % 10 == 0: guess_coef //= 10 @@ -260,18 +426,14 @@ fn sqrt(x: Decimal) raises -> Decimal: else: # print("DEBUG: guess", guess) # print("DEBUG: guess_coef after removing trailing zeros", guess_coef) - if (guess_coef * guess_coef == x_coef) or ( - guess_coef * guess_coef == x_coef * 10 + var guess_coef_squared = guess_coef * guess_coef + if (guess_coef_squared == x_coef) or ( + guess_coef_squared == x_coef * 10 ): - var low = UInt32(guess_coef & 0xFFFFFFFF) - var mid = UInt32((guess_coef >> 32) & 0xFFFFFFFF) - var high = UInt32((guess_coef >> 64) & 0xFFFFFFFF) - return Decimal( - low, - mid, - high, - guess.scale() - num_digits_to_decrease, - False, + return Decimal.from_uint128( + guess_coef, + scale=guess.scale() - num_digits_to_decrease, + sign=False, ) return guess diff --git a/tests/test_root.mojo b/tests/test_root.mojo new file mode 100644 index 0000000..03d905b --- /dev/null +++ b/tests/test_root.mojo @@ -0,0 +1,298 @@ +""" +Comprehensive tests for the root() function in the DeciMojo library. +Tests various cases including basic nth roots, mathematical identities, +and edge cases to ensure proper calculation of x^(1/n). +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode +from decimojo.exponential import root + + +fn test_basic_root_calculations() raises: + """Test basic root calculations for common values.""" + print("Testing basic root calculations...") + + # Test case 1: Square root (n=2) + var num1 = Decimal(9) + var result1 = root(num1, 2) + testing.assert_equal( + String(result1), "3", "√9 should be 3, got " + String(result1) + ) + + # Test case 2: Cube root (n=3) + var num2 = Decimal(8) + var result2 = root(num2, 3) + testing.assert_equal( + String(result2), "2", "∛8 should be 2, got " + String(result2) + ) + + # Test case 3: Fourth root (n=4) + var num3 = Decimal(16) + var result3 = root(num3, 4) + testing.assert_equal( + String(result3), "2", "∜16 should be 2, got " + String(result3) + ) + + # Test case 4: Square root of non-perfect square + var num4 = Decimal(2) + var result4 = root(num4, 2) + testing.assert_true( + String(result4).startswith("1.4142135623730950488"), + "√2 should be approximately 1.414..., got " + String(result4), + ) + + # Test case 5: Cube root of non-perfect cube + var num5 = Decimal(10) + var result5 = root(num5, 3) + testing.assert_true( + String(result5).startswith("2.154434690031883721"), + "∛10 should be approximately 2.154..., got " + String(result5), + ) + + print("✓ Basic root calculations tests passed!") + + +fn test_fractional_inputs() raises: + """Test root calculations with fractional inputs.""" + print("Testing root calculations with fractional inputs...") + + # Test case 1: Square root of decimal + var num1 = Decimal("0.25") + var result1 = root(num1, 2) + testing.assert_equal( + String(result1), "0.5", "√0.25 should be 0.5, got " + String(result1) + ) + + # Test case 2: Cube root of decimal + var num2 = Decimal("0.125") + var result2 = root(num2, 3) + testing.assert_equal( + String(result2), "0.5", "∛0.125 should be 0.5, got " + String(result2) + ) + + # Test case 3: High precision decimal input + var num3 = Decimal("1.44") + var result3 = root(num3, 2) + testing.assert_true( + String(result3).startswith("1.2"), + "√1.44 should be 1.2, got " + String(result3), + ) + + # Test case 4: Decimal input with non-integer result + var num4 = Decimal("0.5") + var result4 = root(num4, 2) + testing.assert_true( + String(result4).startswith("0.7071067811865475"), + "√0.5 should be approximately 0.7071..., got " + String(result4), + ) + + print("✓ Fractional input tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for the root function.""" + print("Testing root edge cases...") + + # Test case 1: Root of 0 + var zero = Decimal(0) + var result1 = root(zero, 2) + testing.assert_equal( + String(result1), "0", "√0 should be 0, got " + String(result1) + ) + + # Test case 2: Root of 1 + var one = Decimal(1) + var result2 = root(one, 100) # Any root of 1 is 1 + testing.assert_equal( + String(result2), + "1", + "100th root of 1 should be 1, got " + String(result2), + ) + + # Test case 3: 1st root of any number is the number itself + var num3 = Decimal("123.456") + var result3 = root(num3, 1) + testing.assert_equal( + String(result3), + "123.456", + "1st root of 123.456 should be 123.456, got " + String(result3), + ) + + # Test case 4: Very large root of a number + var num4 = Decimal(10) + var result4 = root(num4, 100) # 100th root of 10 + testing.assert_true( + String(result4).startswith("1.02329299228075413096627517"), + "100th root of 10 should be approximately" + " 1.02329299228075413096627517..., got " + + String(result4), + ) + + print("✓ Edge cases tests passed!") + + +fn test_error_conditions() raises: + """Test error conditions for the root function.""" + print("Testing root error conditions...") + + # Test case 1: 0th root (should raise error) + var num1 = Decimal(10) + var exception_caught = False + try: + var _result = root(num1, 0) + testing.assert_equal( + True, False, "0th root should raise error but didn't" + ) + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + # Test case 2: Negative root (should raise error) + var num2 = Decimal(10) + exception_caught = False + try: + var _result = root(num2, -2) + testing.assert_equal( + True, False, "Negative root should raise error but didn't" + ) + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + # Test case 3: Negative number with even root (should raise error) + var num3 = Decimal(-4) + exception_caught = False + try: + var _result = root(num3, 2) + testing.assert_equal( + True, + False, + "Even root of negative number should raise error but didn't", + ) + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + # Test case 4: Negative number with odd root (should work) + var num4 = Decimal(-8) + var result4 = root(num4, 3) + testing.assert_equal( + String(result4), "-2", "∛-8 should be -2, got " + String(result4) + ) + + print("✓ Error conditions tests passed!") + + +fn test_precision() raises: + """Test precision of root calculations.""" + print("Testing precision of root calculations...") + + # Test case 1: High precision square root + var num1 = Decimal(2) + var result1 = root(num1, 2) + testing.assert_true( + String(result1).startswith("1.414213562373095048801688724"), + "√2 with high precision should be accurate to at least 25 digits", + ) + + # Test case 2: High precision cube root + var num2 = Decimal(2) + var result2 = root(num2, 3) + testing.assert_true( + String(result2).startswith("1.25992104989487316476721060"), + "∛2 with high precision should be accurate to at least 25 digits", + ) + + # Test case 3: Compare with known precise values + var num3 = Decimal(5) + var result3 = root(num3, 2) + testing.assert_true( + String(result3).startswith("2.236067977499789696"), + "√5 should match known value starting with 2.236067977499789696...", + ) + + print("✓ Precision tests passed!") + + +fn test_mathematical_identities() raises: + """Test mathematical identities involving roots.""" + print("Testing mathematical identities involving roots...") + + # Test case 1: (√x)^2 = x + var x1 = Decimal(7) + var sqrt_x1 = root(x1, 2) + var squared_back = sqrt_x1 * sqrt_x1 + testing.assert_true( + abs(squared_back - x1) < Decimal("0.0000000001"), + "(√x)^2 should equal x within tolerance", + ) + + # Test case 2: ∛(x^3) = x + var x2 = Decimal(3) + var cubed = x2 * x2 * x2 + var root_back = root(cubed, 3) + testing.assert_true( + abs(root_back - x2) < Decimal("0.0000000001"), + "∛(x^3) should equal x within tolerance", + ) + + # Test case 3: √(a*b) = √a * √b + var a = Decimal(4) + var b = Decimal(9) + var sqrt_product = root(a * b, 2) + var product_sqrts = root(a, 2) * root(b, 2) + testing.assert_true( + abs(sqrt_product - product_sqrts) < Decimal("0.0000000001"), + "√(a*b) should equal √a * √b within tolerance", + ) + + # Test case 4: Consistency with power function: x^(1/n) = nth root of x + var x4 = Decimal(5) + var n = 3 # Cube root + var power_result = x4 ** (Decimal(1) / Decimal(n)) + var root_result = root(x4, n) + testing.assert_true( + abs(power_result - root_result) < Decimal("0.0000000001"), + "x^(1/n) should equal nth root of x within tolerance", + ) + + print("✓ Mathematical identities tests passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Root Function Tests") + print("=========================================") + + run_test_with_error_handling( + test_basic_root_calculations, "Basic root calculations test" + ) + run_test_with_error_handling( + test_fractional_inputs, "Fractional inputs test" + ) + run_test_with_error_handling(test_edge_cases, "Edge cases test") + run_test_with_error_handling(test_error_conditions, "Error conditions test") + run_test_with_error_handling(test_precision, "Precision test") + run_test_with_error_handling( + test_mathematical_identities, "Mathematical identities test" + ) + + print("All root function tests passed!")