diff --git a/benches/bench.mojo b/benches/bench.mojo index 1acb8a0..73b9acf 100644 --- a/benches/bench.mojo +++ b/benches/bench.mojo @@ -8,6 +8,8 @@ from bench_from_string import main as bench_from_string from bench_round import main as bench_round from bench_comparison import main as bench_comparison from bench_exp import main as bench_exp +from bench_ln import main as bench_ln +from bench_power import main as bench_power fn main() raises: @@ -21,3 +23,5 @@ fn main() raises: bench_round() bench_comparison() bench_exp() + bench_ln() + bench_power() diff --git a/benches/bench_multiply.mojo b/benches/bench_multiply.mojo index 27531c9..f6f2ef9 100644 --- a/benches/bench_multiply.mojo +++ b/benches/bench_multiply.mojo @@ -299,7 +299,7 @@ fn main() raises: # Case 11: Decimal multiplication with many digits after the decimal point var case11_a_mojo = Decimal.E() - var case11_b_mojo = Decimal.E05() + var case11_b_mojo = dm.constants.E0D5() var case11_a_py = pydecimal.Decimal("1").exp() var case11_b_py = pydecimal.Decimal("0.5").exp() run_benchmark( diff --git a/benches/bench_power.mojo b/benches/bench_power.mojo new file mode 100644 index 0000000..f8547b1 --- /dev/null +++ b/benches/bench_power.mojo @@ -0,0 +1,390 @@ +""" +Comprehensive benchmarks for Decimal power function. +Compares performance against Python's decimal module. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_power_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark( + name: String, + base_value: String, + exponent_value: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal power with Python Decimal power. + + Args: + name: Name of the benchmark case. + base_value: String representation of the base Decimal. + exponent_value: String representation of the exponent Decimal. + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Base: " + base_value, log_file) + log_print("Exponent: " + exponent_value, log_file) + + # Set up Mojo and Python values + var mojo_base = Decimal(base_value) + var mojo_exponent = Decimal(exponent_value) + var pydecimal = Python.import_module("decimal") + var py_base = pydecimal.Decimal(base_value) + var py_exponent = pydecimal.Decimal(exponent_value) + + # Execute the operations once to verify correctness + var mojo_result = dm.exponential.power(mojo_base, mojo_exponent) + var py_result = py_base**py_exponent + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = dm.exponential.power(mojo_base, mojo_exponent) + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_base**py_exponent + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo power(): " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python power(): " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Power Function Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 100 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning power function benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Integer base and exponent + run_benchmark( + "Integer base and exponent", + "2", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Decimal base and integer exponent + run_benchmark( + "Decimal base and integer exponent", + "2.5", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Integer base and decimal exponent + run_benchmark( + "Integer base and decimal exponent", + "9", + "0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Decimal base and exponent + run_benchmark( + "Decimal base and exponent", + "2.0", + "1.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Negative exponent + run_benchmark( + "Negative exponent", + "4", + "-0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 6: Large base and exponent + run_benchmark( + "Large base and exponent", + "12345", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 7: Small base and exponent + run_benchmark( + "Small base and exponent", + "0.5", + "0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 8: High precision base and exponent + run_benchmark( + "High precision base and exponent", + "1.234567890123456789", + "2.345678901234567890", + iterations, + log_file, + speedup_factors, + ) + + # Case 9: Base close to 1 + run_benchmark( + "Base close to 1", + "1.000000001", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Exponent close to 1 + run_benchmark( + "Exponent close to 1", + "2", + "1.000000001", + iterations, + log_file, + speedup_factors, + ) + + # Case 11: Zero base and positive exponent + run_benchmark( + "Zero base and positive exponent", + "0", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 12: One base and any exponent + run_benchmark( + "One base and any exponent", + "1", + "3.14", + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Large base and small exponent + run_benchmark( + "Large base and small exponent", + "1000000", + "0.1", + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Small base and large exponent + run_benchmark( + "Small base and large exponent", + "0.00001", + "10", + iterations, + log_file, + speedup_factors, + ) + + # Case 15: Base greater than 1 and negative exponent + run_benchmark( + "Base greater than 1 and negative exponent", + "2", + "-3", + iterations, + log_file, + speedup_factors, + ) + + # Case 16: Base less than 1 and negative exponent + run_benchmark( + "Base less than 1 and negative exponent", + "0.5", + "-2", + iterations, + log_file, + speedup_factors, + ) + + # Case 17: Base with many digits and exponent with few digits + run_benchmark( + "Base with many digits and exponent with few digits", + "1.234567890123456789", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 18: Base with few digits and exponent with many digits + run_benchmark( + "Base with few digits and exponent with many digits", + "2", + "1.234567890123456789", + iterations, + log_file, + speedup_factors, + ) + + # Case 19: Base and exponent with alternating digits + run_benchmark( + "Base and exponent with alternating digits", + "1.01010101", + "2.02020202", + iterations, + log_file, + speedup_factors, + ) + + # Case 20: Base and exponent with specific pattern + run_benchmark( + "Base and exponent with specific pattern", + "3.14159", + "2.71828", + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== Power Function Benchmark Summary ===", log_file) + log_print("Benchmarked: 20 different power() cases", log_file) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print("Average speedup: " + String(average_speedup) + "×", log_file) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)), + log_file, + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/mojoproject.toml b/mojoproject.toml index ef3a35a..272c874 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -45,6 +45,7 @@ test_comparison = "magic run package && magic run mojo test tests/test_compariso test_factorial = "magic run package && magic run mojo test tests/test_factorial.mojo && magic run delete_package" test_exp = "magic run package && magic run mojo test tests/test_exp.mojo && magic run delete_package" test_ln = "magic run package && magic run mojo test tests/test_ln.mojo && magic run delete_package" +test_power = "magic run package && magic run mojo test tests/test_power.mojo && magic run delete_package" # benches bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" @@ -58,6 +59,7 @@ bench_from_string = "magic run package && cd benches && magic run mojo bench_fro bench_comparison = "magic run package && cd benches && magic run mojo bench_comparison.mojo && cd .. && magic run delete_package" bench_exp = "magic run package && cd benches && magic run mojo bench_exp.mojo && cd .. && magic run delete_package" bench_ln = "magic run package && cd benches && magic run mojo bench_ln.mojo && cd .. && magic run delete_package" +bench_power = "magic run package && cd benches && magic run mojo bench_power.mojo && cd .. && magic run delete_package" # before commit final = "magic run test && magic run bench" diff --git a/src/decimojo/arithmetics.mojo b/src/decimojo/arithmetics.mojo index fa2ad39..a2ee83a 100644 --- a/src/decimojo/arithmetics.mojo +++ b/src/decimojo/arithmetics.mojo @@ -495,8 +495,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # SPECIAL CASE: Both operands are true integers if x1_scale == 0 and x2_scale == 0: - print("DEBUG: Both operands are true integers") - print("DEBUG: combined_num_bits: ", combined_num_bits) + # print("DEBUG: Both operands are true integers") + # print("DEBUG: combined_num_bits: ", combined_num_bits) # Small integers, use UInt64 multiplication if combined_num_bits <= 64: var prod: UInt64 = UInt64(x1_coef) * UInt64(x2_coef) diff --git a/src/decimojo/constants.mojo b/src/decimojo/constants.mojo index 738a98f..a0ccb0f 100644 --- a/src/decimojo/constants.mojo +++ b/src/decimojo/constants.mojo @@ -25,12 +25,14 @@ # ===----------------------------------------------------------------------=== # # -# Integer constants +# Integer and decimal constants # The prefix "M" stands for a decimal (money) value. # This is a convention in C. # # ===----------------------------------------------------------------------=== # +# Integer constants + @always_inline fn M0() -> Decimal: @@ -98,6 +100,21 @@ fn M10() -> Decimal: return Decimal(0xA, 0x0, 0x0, 0x0) +# Decimal constants + + +@always_inline +fn M0D5() -> Decimal: + """Returns 0.5 as a Decimal.""" + return Decimal(5, 0, 0, 0x10000) + + +@always_inline +fn M0D25() -> Decimal: + """Returns 0.25 as a Decimal.""" + return Decimal(25, 0, 0, 0x20000) + + # ===----------------------------------------------------------------------=== # # # Inverse constants diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index f1a8f59..584931a 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -1059,25 +1059,8 @@ struct Decimal( fn __sub__(self, other: Decimal) raises -> Self: """ Subtracts the other Decimal from self and returns a new Decimal. - - Args: - other: The Decimal to subtract from this Decimal. - - Returns: - A new Decimal containing the difference - - Notes: - This method is implemented using the existing `__add__()` and `__neg__()` methods. - - Examples: - ```console - var a = Decimal("10.5") - var b = Decimal("3.2") - var result = a - b # Returns 7.3 - ``` - . + See `subtract()` for more information. """ - try: return decimojo.arithmetics.subtract(self, other) except e: @@ -1127,29 +1110,25 @@ struct Decimal( return decimojo.arithmetics.true_divide(Decimal(other), self) fn __pow__(self, exponent: Decimal) raises -> Self: + """Raises self to the power of exponent and returns a new Decimal. + See `power()` for more information. """ - Raises self to the power of exponent and returns a new Decimal. - - Currently supports integer exponents only. - - Args: - exponent: The power to raise self to. - It must be an integer or effectively an integer (e.g., 2.0). - - Returns: - A new Decimal containing the result of self^exponent - - Raises: - Error: If exponent is not an integer or if the operation would overflow. - """ - - return decimal.power(self, exponent) + try: + return decimal.power(self, exponent) + except e: + raise Error("Error in `__pow__()`: ", e) fn __pow__(self, exponent: Int) raises -> Self: - return decimal.power(self, exponent) + try: + return decimal.power(self, exponent) + except e: + raise Error("Error in `__pow__()`: ", e) fn __pow__(self, exponent: Float64) raises -> Self: - return decimal.power(self, Decimal(exponent)) + try: + return decimal.power(self, Decimal(exponent)) + except e: + raise Error("Error in `__pow__()`: ", e) # ===------------------------------------------------------------------=== # # Basic binary comparison operation dunders diff --git a/src/decimojo/exponential.mojo b/src/decimojo/exponential.mojo index 4d97c2b..e5fdcef 100644 --- a/src/decimojo/exponential.mojo +++ b/src/decimojo/exponential.mojo @@ -38,36 +38,66 @@ import decimojo.utility fn power(base: Decimal, exponent: Decimal) raises -> Decimal: - """ - Raises base to the power of exponent and returns a new Decimal. + """Raises a Decimal base to an arbitrary Decimal exponent power. - Currently supports integer exponents only. + This function handles both integer and non-integer exponents using the + identity x^y = e^(y * ln(x)). Args: - base: The base value. - exponent: The power to raise base to. - It must be an integer or effectively an integer (e.g., 2.0). + base: The base Decimal value (must be positive). + exponent: The exponent Decimal value (can be any value). Returns: - A new Decimal containing the result of base^exponent + A new Decimal containing the result of base^exponent. Raises: - Error: If exponent is not an integer or if the operation would overflow. - Error: If zero is raised to a negative power. + Error: If the base is negative or the exponent is negative and not an integer. + Error: If an error occurs in calling the power() function with an integer exponent. + Error: If an error occurs in calling the sqrt() function with a Decimal exponent. + Error: If an error occurs in calling the ln() function with a Decimal base. + Error: If an error occurs in calling the exp() function with a Decimal exponent. """ - # Check if exponent is an integer - if not exponent.is_integer(): - raise Error("Power operation is only supported for integer exponents") - # Convert exponent to integer - var exp_value = Int(exponent) + # CASE: If the exponent is integer + if exponent.is_integer(): + try: + return power(base, Int(exponent)) + except e: + raise Error("Error in `power()` with Decimal exponent: ", e) + + # CASE: For negative bases, only integer exponents are supported + if base.is_negative(): + raise Error( + "Negative base with non-integer exponent results in a complex" + " number" + ) - return power(base, exp_value) + # CASE: If the exponent is simple fractions + # 0.5 + if exponent == decimojo.constants.M0D5(): + try: + return sqrt(base) + except e: + raise Error("Error in `power()` with Decimal exponent: ", e) + # -0.5 + if exponent == Decimal(5, 0, 0, 0x80010000): + try: + return Decimal.ONE() / sqrt(base) + except e: + raise Error("Error in `power()` with Decimal exponent: ", e) + + # GENERAL CASE + # Use the identity x^y = e^(y * ln(x)) + try: + var ln_base = ln(base) + var product = exponent * ln_base + return exp(product) + except e: + raise Error("Error in `power()` with Decimal exponent: ", e) fn power(base: Decimal, exponent: Int) raises -> Decimal: - """ - Convenience method to raise base to an integer power. + """Raises a Decimal base to an integer power. Args: base: The base value. @@ -126,8 +156,7 @@ fn power(base: Decimal, exponent: Int) raises -> Decimal: fn sqrt(x: Decimal) raises -> Decimal: - """ - Computes the square root of a Decimal value using Newton-Raphson method. + """Computes the square root of a Decimal value using Newton-Raphson method. Args: x: The Decimal value to compute the square root of. @@ -249,8 +278,7 @@ fn sqrt(x: Decimal) raises -> Decimal: fn exp(x: Decimal) raises -> Decimal: - """ - Calculates e^x for any Decimal value using optimized range reduction. + """Calculates e^x for any Decimal value using optimized range reduction. x should be no greater than 66 to avoid overflow. Args: @@ -298,8 +326,8 @@ fn exp(x: Decimal) raises -> Decimal: return decimojo.constants.E() elif x_int < 1: - var M0D5 = Decimal(5, 0, 0, 1 << 16) # 0.5 - var M0D25 = Decimal(25, 0, 0, 2 << 16) # 0.25 + var M0D5 = decimojo.constants.M0D5() + var M0D25 = decimojo.constants.M0D25() if x < M0D25: # 0 < x < 0.25 return exp_series(x) @@ -394,8 +422,7 @@ fn exp(x: Decimal) raises -> Decimal: fn exp_series(x: Decimal) raises -> Decimal: - """ - Calculates e^x using Taylor series expansion. + """Calculates e^x using Taylor series expansion. Do not use this function for values larger than 1, but `exp()` instead. Args: @@ -445,8 +472,7 @@ fn exp_series(x: Decimal) raises -> Decimal: fn ln(x: Decimal) raises -> Decimal: - """ - Calculates the natural logarithm (ln) of a Decimal value. + """Calculates the natural logarithm (ln) of a Decimal value. Args: x: The Decimal value to compute the natural logarithm of. @@ -651,8 +677,7 @@ fn ln(x: Decimal) raises -> Decimal: fn ln_series(z: Decimal) raises -> Decimal: - """ - Calculates ln(1+z) using Taylor series expansion at 1. + """Calculates ln(1+z) using Taylor series expansion at 1. For best accuracy, |z| should be small (< 0.5). Args: diff --git a/tests/test_arithmetics.mojo b/tests/test_arithmetics.mojo index 4c61865..993f970 100644 --- a/tests/test_arithmetics.mojo +++ b/tests/test_arithmetics.mojo @@ -466,197 +466,6 @@ fn test_subtract() raises: print("Decimal subtraction tests passed!") -fn test_power_integer_exponents() raises: - print("------------------------------------------------------") - print("Testing power with integer exponents...") - - # Test case 1: Base cases: x^0 = 1 for any x except 0 - var a1 = Decimal("2.5") - var result1 = a1**0 - testing.assert_equal( - String(result1), "1", "Any number to power 0 should be 1" - ) - - # Test case 2: 0^n = 0 for n > 0 - var a2 = Decimal("0") - var result2 = a2**5 - testing.assert_equal( - String(result2), "0", "0 to any positive power should be 0" - ) - - # Test case 3: x^1 = x - var a3 = Decimal("3.14159") - var result3 = a3**1 - testing.assert_equal(String(result3), "3.14159", "x^1 should be x") - - # Test case 4: Positive integer powers - var a4 = Decimal("2") - var result4 = a4**3 - testing.assert_equal(String(result4), "8", "2^3 should be 8") - - # Test case 5: Test with scale - var a5 = Decimal("1.5") - var result5 = a5**2 - testing.assert_equal(String(result5), "2.25", "1.5^2 should be 2.25") - - # Test case 6: Larger powers - var a6 = Decimal("2") - var result6 = a6**10 - testing.assert_equal(String(result6), "1024", "2^10 should be 1024") - - # Test case 7: Negative base, even power - var a7 = Decimal("-3") - var result7 = a7**2 - testing.assert_equal(String(result7), "9", "(-3)^2 should be 9") - - # Test case 8: Negative base, odd power - var a8 = Decimal("-3") - var result8 = a8**3 - testing.assert_equal(String(result8), "-27", "(-3)^3 should be -27") - - # Test case 9: Decimal base, positive power - var a9 = Decimal("0.1") - var result9 = a9**3 - testing.assert_equal(String(result9), "0.001", "0.1^3 should be 0.001") - - # Test case 10: Large number to small power - var a10 = Decimal("1000") - var result10 = a10**2 - testing.assert_equal( - String(result10), "1000000", "1000^2 should be 1000000" - ) - - print("Integer exponent tests passed!") - - -fn test_power_negative_exponents() raises: - print("------------------------------------------------------") - print("Testing power with negative integer exponents...") - - # Test case 1: Basic negative exponent - var a1 = Decimal("2") - var result1 = a1 ** (-2) - testing.assert_equal(String(result1), "0.25", "2^(-2) should be 0.25") - - # Test case 2: Larger negative exponent - var a2 = Decimal("10") - var result2 = a2 ** (-3) - testing.assert_equal(String(result2), "0.001", "10^(-3) should be 0.001") - - # Test case 3: Negative base, even negative power - var a3 = Decimal("-2") - var result3 = a3 ** (-2) - testing.assert_equal(String(result3), "0.25", "(-2)^(-2) should be 0.25") - - # Test case 4: Negative base, odd negative power - var a4 = Decimal("-2") - var result4 = a4 ** (-3) - testing.assert_equal( - String(result4), "-0.125", "(-2)^(-3) should be -0.125" - ) - - # Test case 5: Decimal base, negative power - var a5 = Decimal("0.5") - var result5 = a5 ** (-2) - testing.assert_equal(String(result5), "4", "0.5^(-2) should be 4") - - # Test case 6: 1^(-n) = 1 - var a6 = Decimal("1") - var result6 = a6 ** (-5) - testing.assert_equal(String(result6), "1", "1^(-5) should be 1") - - print("Negative exponent tests passed!") - - -fn test_power_special_cases() raises: - print("------------------------------------------------------") - print("Testing power function special cases...") - - # Test case 1: 0^0 (typically defined as 1) - var a1 = Decimal("0") - try: - var result1 = a1**0 - testing.assert_equal(String(result1), "1", "0^0 should be defined as 1") - except: - print("0^0 raises an exception (mathematically undefined)") - - # Test case 2: 0^(-n) (mathematically undefined) - var a2 = Decimal("0") - try: - var result2 = a2 ** (-2) - print("WARNING: 0^(-2) didn't raise an exception, got", result2) - except: - print("0^(-2) correctly raises an exception") - - # Test case 3: 1^n = 1 for any n - var a3 = Decimal("1") - var result3a = a3**100 - var result3b = a3 ** (-100) - testing.assert_equal(String(result3a), "1", "1^100 should be 1") - testing.assert_equal(String(result3b), "1", "1^(-100) should be 1") - - # Test case 4: High precision result with rounding - # TODO: Implement __gt__ - # var a4 = Decimal("1.1") - # var result4 = a4**30 - # testing.assert_true( - # result4 > Decimal("17.4") and result4 < Decimal("17.5"), - # "1.1^30 should be approximately 17.449", - # ) - - print("Special cases tests passed!") - - -fn test_power_decimal_exponents() raises: - print("------------------------------------------------------") - print("Testing power with decimal exponents...") - - # Try a few basic decimal exponents if supported - try: - var a1 = Decimal("4") - var e1 = Decimal("0.5") # Square root - var result1 = a1**e1 - testing.assert_equal(String(result1), "2", "4^0.5 should be 2") - - var a2 = Decimal("8") - var e2 = Decimal("1.5") # Cube root of square - var result2 = a2**e2 - testing.assert_equal( - String(result2)[:4], "22.6", "8^1.5 should be approximately 22.6" - ) - except: - print("Decimal exponents not supported in this implementation") - - print("Decimal exponent tests passed!") - - -fn test_power_precision() raises: - print("------------------------------------------------------") - print("Testing power precision...") - - # These tests assume we have overloaded the ** operator - # and we have a way to control precision similar to pow() - try: - # Test with precision control - var a1 = Decimal("1.5") - var result1 = a1**2 - # Test equality including precision - testing.assert_equal( - String(result1), "2.25", "1.5^2 should be exactly 2.25" - ) - - # Check scale - testing.assert_equal( - result1.scale(), - 2, - "Result should maintain precision of 2 decimal places", - ) - except: - print("Precision parameters not supported with ** operator") - - print("Precision tests passed!") - - fn test_extreme_cases() raises: print("------------------------------------------------------") print("Testing extreme cases...") @@ -725,22 +534,6 @@ fn main() raises: # Run subtraction tests test_subtract() - # Run power tests with integer exponents - test_power_integer_exponents() - - # Run power tests with negative exponents - - test_power_negative_exponents() - - # Run power tests for special cases - test_power_special_cases() - - # Run power tests with decimal exponents - test_power_decimal_exponents() - - # Run power precision tests - test_power_precision() - # Run extreme cases tests test_extreme_cases() diff --git a/tests/test_power.mojo b/tests/test_power.mojo new file mode 100644 index 0000000..522078c --- /dev/null +++ b/tests/test_power.mojo @@ -0,0 +1,179 @@ +""" +Comprehensive tests for the power function of the Decimal type. +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode +from decimojo.exponential import power + + +fn test_integer_powers() raises: + """Test raising a Decimal to an integer power.""" + print("Testing integer powers...") + + # Test case 1: Positive base, positive exponent + var base1 = Decimal("2") + var exponent1 = 3 + var result1 = power(base1, exponent1) + testing.assert_equal( + String(result1), "8", "2^3 should be 8, got " + String(result1) + ) + + # Test case 2: Positive base, zero exponent + var base2 = Decimal("5") + var exponent2 = 0 + var result2 = power(base2, exponent2) + testing.assert_equal( + String(result2), "1", "5^0 should be 1, got " + String(result2) + ) + + # Test case 3: Positive base, negative exponent + var base3 = Decimal("2") + var exponent3 = -2 + var result3 = power(base3, exponent3) + testing.assert_equal( + String(result3), "0.25", "2^-2 should be 0.25, got " + String(result3) + ) + + # Test case 4: Decimal base, positive exponent + var base4 = Decimal("2.5") + var exponent4 = 2 + var result4 = power(base4, exponent4) + testing.assert_equal( + String(result4), "6.25", "2.5^2 should be 6.25, got " + String(result4) + ) + + # Test case 5: Decimal base, negative exponent + var base5 = Decimal("0.5") + var exponent5 = -1 + var result5 = power(base5, exponent5) + testing.assert_equal( + String(result5), "2", "0.5^-1 should be 2, got " + String(result5) + ) + + print("✓ Integer powers tests passed!") + + +fn test_decimal_powers() raises: + """Test raising a Decimal to a Decimal power.""" + print("Testing decimal powers...") + + # Test case 1: Positive base, simple fractional exponent (0.5) + var base1 = Decimal("9") + var exponent1 = Decimal("0.5") + var result1 = power(base1, exponent1) + testing.assert_equal( + String(result1), "3", "9^0.5 should be 3, got " + String(result1) + ) + + # Test case 2: Positive base, more complex fractional exponent + var base2 = Decimal("2") + var exponent2 = Decimal("1.5") + var result2 = power(base2, exponent2) + testing.assert_true( + String(result2).startswith("2.828427124746190097603377448"), + "2^1.5 should be approximately 2.828..., got " + String(result2), + ) + + # Test case 3: Decimal base, decimal exponent + var base3 = Decimal("2.5") + var exponent3 = Decimal("0.5") + var result3 = power(base3, exponent3) + testing.assert_true( + String(result3).startswith("1.5811388300841896659994467722"), + "2.5^0.5 should be approximately 1.5811388300841896659994467722...," + " got " + + String(result3), + ) + + # Test case 4: Base > 1, exponent < 0 + var base4 = Decimal("4") + var exponent4 = Decimal("-0.5") + var result4 = power(base4, exponent4) + testing.assert_equal( + String(result4), "0.5", "4^-0.5 should be 0.5, got " + String(result4) + ) + + print("✓ Decimal powers tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for the power function.""" + print("Testing power edge cases...") + + # Test case 1: Zero base, positive exponent + var base1 = Decimal("0") + var exponent1 = Decimal("2") + var result1 = power(base1, exponent1) + testing.assert_equal( + String(result1), "0", "0^2 should be 0, got " + String(result1) + ) + + # Test case 2: Zero base, negative exponent (should raise error) + var base2 = Decimal("0") + var exponent2 = Decimal("-2") + var exception_caught = False + try: + var _result = power(base2, exponent2) + testing.assert_equal( + True, False, "0^-2 should raise an exception, but it didn't" + ) + except: + exception_caught = True + testing.assert_equal( + exception_caught, True, "0^-2 should raise an exception" + ) + + # Test case 3: Negative base, integer exponent + var base3 = Decimal("-2") + var exponent3 = Decimal("3") + var result3 = power(base3, exponent3) + testing.assert_equal( + String(result3), "-8", "(-2)^3 should be -8, got " + String(result3) + ) + + # Test case 4: Negative base, non-integer exponent (should raise error) + var base4 = Decimal("-2") + var exponent4 = Decimal("0.5") + exception_caught = False + try: + var _result2 = power(base4, exponent4) + testing.assert_equal( + True, False, "(-2)^0.5 should raise an exception, but it didn't" + ) + except: + exception_caught = True + testing.assert_equal( + exception_caught, True, "(-2)^0.5 should raise an exception" + ) + + print("✓ Edge cases tests passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Decimal Power Function Tests") + print("=========================================") + + run_test_with_error_handling(test_integer_powers, "Integer powers test") + run_test_with_error_handling(test_decimal_powers, "Decimal powers test") + run_test_with_error_handling(test_edge_cases, "Edge cases test") + + print("All power function tests passed!")