diff --git a/benches/bench_exp.mojo b/benches/bench_exp.mojo new file mode 100644 index 0000000..195fbac --- /dev/null +++ b/benches/bench_exp.mojo @@ -0,0 +1,366 @@ +""" +Comprehensive benchmarks for Decimal exponential function (exp). +Compares performance against Python's decimal module with 20 diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_exp_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark( + name: String, + input_value: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal exp with Python Decimal exp. + + Args: + name: Name of the benchmark case. + input_value: String representation of value for exp(x). + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Input value: " + input_value, log_file) + + # Set up Mojo and Python values + var mojo_decimal = Decimal(input_value) + var pydecimal = Python.import_module("decimal") + var py_decimal = pydecimal.Decimal(input_value) + var py_math = Python.import_module("math") + + # Execute the operations once to verify correctness + var mojo_result = dm.exponential.exp(mojo_decimal) + var py_result = py_decimal.exp() + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = dm.exponential.exp(mojo_decimal) + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_decimal.exp() + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo exp(): " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python exp(): " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Exponential Function (exp) Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 100 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning exponential function benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: exp(0) = 1 + run_benchmark( + "exp(0) = 1", + "0", + iterations, + log_file, + speedup_factors, + ) + + # Case 2: exp(1) ≈ e + run_benchmark( + "exp(1) ≈ e", + "1", + iterations, + log_file, + speedup_factors, + ) + + # Case 3: exp(2) ≈ 7.389... + run_benchmark( + "exp(2)", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 4: exp(-1) = 1/e + run_benchmark( + "exp(-1) = 1/e", + "-1", + iterations, + log_file, + speedup_factors, + ) + + # Case 5: exp(0.5) ≈ sqrt(e) + run_benchmark( + "exp(0.5) ≈ sqrt(e)", + "0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 6: exp(-0.5) ≈ 1/sqrt(e) + run_benchmark( + "exp(-0.5) ≈ 1/sqrt(e)", + "-0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 7: exp with small positive value + run_benchmark( + "Small positive value", + "0.0001", + iterations, + log_file, + speedup_factors, + ) + + # Case 8: exp with very small positive value + run_benchmark( + "Very small positive value", + "0.000000001", + iterations, + log_file, + speedup_factors, + ) + + # Case 9: exp with small negative value + run_benchmark( + "Small negative value", + "-0.0001", + iterations, + log_file, + speedup_factors, + ) + + # Case 10: exp with very small negative value + run_benchmark( + "Very small negative value", + "-0.000000001", + iterations, + log_file, + speedup_factors, + ) + + # Case 11: exp with moderate value (e^3) + run_benchmark( + "Moderate value (e^3)", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 12: exp with moderate negative value (e^-3) + run_benchmark( + "Moderate negative value (e^-3)", + "-3", + iterations, + log_file, + speedup_factors, + ) + + # Case 13: exp with large value (e^10) + run_benchmark( + "Large value (e^10)", + "10", + iterations, + log_file, + speedup_factors, + ) + + # Case 14: exp with large negative value (e^-10) + run_benchmark( + "Large negative value (e^-10)", + "-10", + iterations, + log_file, + speedup_factors, + ) + + # Case 15: exp with Pi + run_benchmark( + "exp(π)", + "3.14159265358979323846", + iterations, + log_file, + speedup_factors, + ) + + # Case 16: exp with high precision input + run_benchmark( + "High precision input", + "1.234567890123456789", + iterations, + log_file, + speedup_factors, + ) + + # Case 17: exp with fractional value + run_benchmark( + "Fractional value (e^1.5)", + "1.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 18: exp with negative fractional value + run_benchmark( + "Negative fractional value (e^-1.5)", + "-1.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 19: exp with approximate e value + run_benchmark( + "Approximate e value", + "2.718281828459045", + iterations, + log_file, + speedup_factors, + ) + + # Case 20: exp with larger value (e^15) + run_benchmark( + "Larger value (e^15)", + "15", + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== Exponential Function Benchmark Summary ===", log_file) + log_print("Benchmarked: 20 different exp() cases", log_file) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print("Average speedup: " + String(average_speedup) + "×", log_file) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)), + log_file, + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/benches/bench_multiply.mojo b/benches/bench_multiply.mojo index 7810339..27531c9 100644 --- a/benches/bench_multiply.mojo +++ b/benches/bench_multiply.mojo @@ -297,9 +297,24 @@ fn main() raises: log_file, ) + # Case 11: Decimal multiplication with many digits after the decimal point + var case11_a_mojo = Decimal.E() + var case11_b_mojo = Decimal.E05() + var case11_a_py = pydecimal.Decimal("1").exp() + var case11_b_py = pydecimal.Decimal("0.5").exp() + run_benchmark( + "e * e^0.5", + case11_a_mojo, + case11_b_mojo, + case11_a_py, + case11_b_py, + iterations, + log_file, + ) + # Display summary log_print("\n=== Multiplication Benchmark Summary ===", log_file) - log_print("Benchmarked: 10 different multiplication cases", log_file) + log_print("Benchmarked: 11 different multiplication cases", log_file) log_print( "Each case ran: " + String(iterations) + " iterations", log_file ) diff --git a/docs/todo.md b/docs/todo.md new file mode 100644 index 0000000..f5ed6d0 --- /dev/null +++ b/docs/todo.md @@ -0,0 +1,5 @@ +# TODO + +This is a to-do list for Yuhao's personal use. + +- The `exp()` function performs slower than Python's counterpart in specific cases. Detailed investigation reveals the bottleneck stems from multiplication operations between decimals with significant fractional components. These operations currently rely on UInt256 arithmetic, which introduces performance overhead. Optimization of the `multiply()` function is required to address these performance bottlenecks, particularly for high-precision decimal multiplication with many digits after the decimal point. \ No newline at end of file diff --git a/mojoproject.toml b/mojoproject.toml index 6ec5209..93fc70c 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -40,7 +40,8 @@ test_creation = "magic run package && magic run mojo test tests/test_creation.mo test_from_float = "magic run package && magic run mojo test tests/test_from_float.mojo && magic run delete_package" test_from_string = "magic run package && magic run mojo test tests/test_from_string.mojo && magic run delete_package" test_comparison = "magic run package && magic run mojo test tests/test_comparison.mojo && magic run delete_package" - +test_factorial = "magic run package && magic run mojo test tests/test_factorial.mojo && magic run delete_package" +test_exp = "magic run package && magic run mojo test tests/test_exp.mojo && magic run delete_package" # benches bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" @@ -52,6 +53,7 @@ bench_round = "magic run package && cd benches && magic run mojo bench_round.moj bench_from_float = "magic run package && cd benches && magic run mojo bench_from_float.mojo && cd .. && magic run delete_package" bench_from_string = "magic run package && cd benches && magic run mojo bench_from_string.mojo && cd .. && magic run delete_package" bench_comparison = "magic run package && cd benches && magic run mojo bench_comparison.mojo && cd .. && magic run delete_package" +bench_exp = "magic run package && cd benches && magic run mojo bench_exp.mojo && cd .. && magic run delete_package" # before commit final = "magic run test && magic run bench" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 3488cfa..1c66b8e 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimal/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # """ @@ -18,15 +33,13 @@ from .decimal import Decimal from .rounding_mode import RoundingMode -from .maths import ( +from .arithmetics import ( add, subtract, + absolute, + negative, multiply, true_divide, - power, - sqrt, - round, - absolute, ) from .comparison import ( @@ -37,3 +50,11 @@ from .comparison import ( equal, not_equal, ) + +from .exponential import power, sqrt, exp + +from .rounding import round + +from .special import ( + factorial, +) diff --git a/src/decimojo/maths/arithmetics.mojo b/src/decimojo/arithmetics.mojo similarity index 91% rename from src/decimojo/maths/arithmetics.mojo rename to src/decimojo/arithmetics.mojo index 49bae35..32ab77d 100644 --- a/src/decimojo/maths/arithmetics.mojo +++ b/src/decimojo/arithmetics.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # # # Implements basic arithmetic functions for the Decimal type @@ -44,31 +59,59 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: Raises: Error: If the operation would overflow. """ + var x1_coef = x1.coefficient() + var x2_coef = x2.coefficient() + var x1_scale = x1.scale() + var x2_scale = x2.scale() - # Special case for zero - if x1.is_zero(): - return Decimal( - x2.low, - x2.mid, - x2.high, - max(x1.scale(), x2.scale()), - x1.flags & x2.flags == Decimal.SIGN_MASK, - ) + # Special case for zeros + + if x1_coef == 0 and x2_coef == 0: + var scale = max(x1_scale, x2_scale) + return Decimal(0, 0, 0, scale, False) - elif x2.is_zero(): - return Decimal( - x1.low, - x1.mid, - x1.high, - max(x1.scale(), x2.scale()), - x1.flags & x2.flags == Decimal.SIGN_MASK, + elif x1_coef == 0: + var sum_coef = x2_coef + var scale = min( + max(x1_scale, x2_scale), + Decimal.MAX_NUM_DIGITS + - decimojo.utility.number_of_digits(x2.to_uint128()), + ) + ## If x2_coef > 7922816251426433759354395033 + if ( + (x2_coef > Decimal.MAX_AS_UINT128 // 10) + and (scale > 0) + and (scale > x2_scale) + ): + scale -= 1 + sum_coef *= UInt128(10) ** (scale - x2_scale) + var low = UInt32(sum_coef & 0xFFFFFFFF) + var mid = UInt32((sum_coef >> 32) & 0xFFFFFFFF) + var high = UInt32((sum_coef >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, scale, x2.is_negative()) + + elif x2_coef == 0: + var sum_coef = x1_coef + var scale = min( + max(x1_scale, x2_scale), + Decimal.MAX_NUM_DIGITS + - decimojo.utility.number_of_digits(x1.to_uint128()), ) + ## If x1_coef > 7922816251426433759354395033 + if ( + (x1_coef > Decimal.MAX_AS_UINT128 // 10) + and (scale > 0) + and (scale > x1_scale) + ): + scale -= 1 + sum_coef *= UInt128(10) ** (scale - x1_scale) + var low = UInt32(sum_coef & 0xFFFFFFFF) + var mid = UInt32((sum_coef >> 32) & 0xFFFFFFFF) + var high = UInt32((sum_coef >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, scale, x1.is_negative()) # Integer addition with scale of 0 (true integers) - elif x1.scale() == 0 and x2.scale() == 0: - var x1_coef = x1.coefficient() - var x2_coef = x2.coefficient() - + elif x1_scale == 0 and x2_scale == 0: # Same sign: add absolute values and keep the sign if x1.is_negative() == x2.is_negative(): # Add directly using UInt128 arithmetic @@ -118,7 +161,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # Determine the scale for the result var scale = min( - max(x1.scale(), x2.scale()), + max(x1_scale, x2_scale), Decimal.MAX_NUM_DIGITS - decimojo.utility.number_of_digits(summation), ) @@ -138,7 +181,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: else: var diff: UInt128 var is_negative: Bool - if x1.coefficient() > x2.coefficient(): + if x1_coef > x2_coef: diff = x1.to_uint128() - x2.to_uint128() is_negative = x1.is_negative() else: @@ -147,7 +190,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # Determine the scale for the result var scale = min( - max(x1.scale(), x2.scale()), + max(x1_scale, x2_scale), Decimal.MAX_NUM_DIGITS - decimojo.utility.number_of_digits(diff), ) @@ -164,11 +207,11 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: return Decimal(low, mid, high, scale, is_negative) # Float addition with the same scale - elif x1.scale() == x2.scale(): + elif x1_scale == x2_scale: var summation: Int128 # 97-bit signed integer can be stored in Int128 - summation = (-1) ** x1.is_negative() * Int128(x1.coefficient()) + ( + summation = (-1) ** x1.is_negative() * Int128(x1_coef) + ( -1 - ) ** x2.is_negative() * Int128(x2.coefficient()) + ) ** x2.is_negative() * Int128(x2_coef) var is_negative = summation < 0 if is_negative: @@ -180,7 +223,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # If the summation fits in 96 bits, we can use the original scale if summation < Decimal.MAX_AS_INT128: - final_scale = x1.scale() + final_scale = x1_scale # Otherwise, we need to truncate the summation to fit in 96 bits else: @@ -191,7 +234,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: truncated_summation ) - ( decimojo.utility.number_of_digits(summation) - - max(x1.scale(), x2.scale()) + - max(x1_scale, x2_scale) ) # Extract the 32-bit components from the Int256 difference @@ -204,23 +247,21 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # Float addition which with different scales else: var summation: Int256 - if x1.scale() == x2.scale(): - summation = (-1) ** x1.is_negative() * Int256(x1.coefficient()) + ( + if x1_scale == x2_scale: + summation = (-1) ** x1.is_negative() * Int256(x1_coef) + ( -1 - ) ** x2.is_negative() * Int256(x2.coefficient()) - elif x1.scale() > x2.scale(): - summation = (-1) ** x1.is_negative() * Int256(x1.coefficient()) + ( + ) ** x2.is_negative() * Int256(x2_coef) + elif x1_scale > x2_scale: + summation = (-1) ** x1.is_negative() * Int256(x1_coef) + ( -1 - ) ** x2.is_negative() * Int256(x2.coefficient()) * Int256(10) ** ( - x1.scale() - x2.scale() + ) ** x2.is_negative() * Int256(x2_coef) * Int256(10) ** ( + x1_scale - x2_scale ) else: - summation = (-1) ** x1.is_negative() * Int256( - x1.coefficient() - ) * Int256(10) ** (x2.scale() - x1.scale()) + ( - -1 - ) ** x2.is_negative() * Int256( - x2.coefficient() + summation = (-1) ** x1.is_negative() * Int256(x1_coef) * Int256( + 10 + ) ** (x2_scale - x1_scale) + (-1) ** x2.is_negative() * Int256( + x2_coef ) var is_negative = summation < 0 @@ -233,7 +274,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # If the summation fits in 96 bits, we can use the original scale if summation < Decimal.MAX_AS_INT256: - final_scale = max(x1.scale(), x2.scale()) + final_scale = max(x1_scale, x2_scale) # Otherwise, we need to truncate the summation to fit in 96 bits else: @@ -244,7 +285,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: truncated_summation ) - ( decimojo.utility.number_of_digits(summation) - - max(x1.scale(), x2.scale()) + - max(x1_scale, x2_scale) ) # Extract the 32-bit components from the Int256 difference @@ -344,6 +385,13 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: """Combined scale of the two operands.""" var is_negative = x1.is_negative() != x2.is_negative() + # SPECIAL CASE: true one + # Return the other operand + if x1.low == 1 and x1.mid == 0 and x1.high == 0 and x1.flags == 0: + return x2 + if x2.low == 1 and x2.mid == 0 and x2.high == 0 and x2.flags == 0: + return x1 + # SPECIAL CASE: zero # Return zero while preserving the scale if x1_coef == 0 or x2_coef == 0: @@ -436,17 +484,14 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Used to determine the appropriate multiplication method # The coefficient of result would be the sum of the two numbers of bits var x1_num_bits = decimojo.utility.number_of_bits(x1_coef) - """Number of significant bits in the coefficient of x1.""" var x2_num_bits = decimojo.utility.number_of_bits(x2_coef) - """Number of significant bits in the coefficient of x2.""" var combined_num_bits = x1_num_bits + x2_num_bits - """Number of significant bits in the coefficient of the result.""" # SPECIAL CASE: Both operands are true integers if x1_scale == 0 and x2_scale == 0: # Small integers, use UInt64 multiplication if combined_num_bits <= 64: - var prod: UInt64 = UInt64(x1.low) * UInt64(x2.low) + var prod: UInt64 = UInt64(x1_coef) * UInt64(x2_coef) var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) return Decimal(low, mid, 0, 0, is_negative) diff --git a/src/decimojo/comparison.mojo b/src/decimojo/comparison.mojo index 3652ac7..fdbef44 100644 --- a/src/decimojo/comparison.mojo +++ b/src/decimojo/comparison.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # # # Implements comparison operations for the Decimal type diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index ebdd1ed..c534bd1 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # # # Implements basic object methods for the Decimal type @@ -41,8 +56,10 @@ Implements basic object methods for working with decimal numbers. from memory import UnsafePointer +import decimojo.arithmetics import decimojo.comparison -import decimojo.maths +import decimojo.exponential +import decimojo.rounding from decimojo.rounding_mode import RoundingMode import decimojo.utility @@ -117,6 +134,9 @@ struct Decimal( alias NAN_MASK = UInt32(0x00000002) """Not a Number mask. `0b0000_0000_0000_0000_0000_0000_0000_0010`.""" + # TODO: Move these special values to top of the module + # when Mojo support global variables in the future. + # # Special values @staticmethod fn INFINITY() -> Decimal: @@ -188,6 +208,113 @@ struct Decimal( 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, Decimal.SIGN_MASK ) + @staticmethod + fn PI() -> Decimal: + """ + Returns the value of pi (π) as a Decimal. + + Returns: + A Decimal representation of pi with maximum precision. + """ + return Decimal.from_words(0x41B65F29, 0xB143885, 0x6582A536, 0x1C0000) + + @staticmethod + fn E() -> Decimal: + """ + Returns the value of Euler's number (e) as a Decimal. + + Returns: + A Decimal representation of Euler's number with maximum precision. + """ + return Decimal.from_words(0x857AED5A, 0xEBECDE35, 0x57D519AB, 0x1C0000) + + @staticmethod + fn E2() -> Decimal: + return Decimal.from_words(0xE4DFDCAE, 0x89F7E295, 0xEEC0D6E9, 0x1C0000) + + @staticmethod + fn E3() -> Decimal: + return Decimal.from_words(0x236454F7, 0x62055A80, 0x40E65DE2, 0x1B0000) + + @staticmethod + fn E4() -> Decimal: + return Decimal.from_words(0x7121EFD3, 0xFB318FB5, 0xB06A87FB, 0x1B0000) + + @staticmethod + fn E5() -> Decimal: + return Decimal.from_words(0xD99BD974, 0x9F4BE5C7, 0x2FF472E3, 0x1A0000) + + @staticmethod + fn E6() -> Decimal: + return Decimal.from_words(0xADB57A66, 0xBD7A423F, 0x825AD8FF, 0x1A0000) + + @staticmethod + fn E7() -> Decimal: + return Decimal.from_words(0x22313FCF, 0x64D5D12F, 0x236F230A, 0x190000) + + @staticmethod + fn E8() -> Decimal: + return Decimal.from_words(0x1E892E63, 0xD1BF8B5C, 0x6051E812, 0x190000) + + @staticmethod + fn E9() -> Decimal: + return Decimal.from_words(0x34FAB691, 0xE7CD8DEA, 0x1A2EB6C3, 0x180000) + + @staticmethod + fn E10() -> Decimal: + return Decimal.from_words(0xBA7F4F65, 0x58692B62, 0x472BDD8F, 0x180000) + + @staticmethod + fn E11() -> Decimal: + return Decimal.from_words(0x8C2C6D20, 0x2A86F9E7, 0xC176BAAE, 0x180000) + + @staticmethod + fn E12() -> Decimal: + return Decimal.from_words(0xE924992A, 0x31CDC314, 0x3496C2C4, 0x170000) + + @staticmethod + fn E13() -> Decimal: + return Decimal.from_words(0x220130DB, 0xC386029A, 0x8EF393FB, 0x170000) + + @staticmethod + fn E14() -> Decimal: + return Decimal.from_words(0x3A24795C, 0xC412DF01, 0x26DBB5A0, 0x160000) + + @staticmethod + fn E15() -> Decimal: + """Returns the value of e^15 as a Decimal.""" + return Decimal.from_words(0x6C1248BD, 0x90456557, 0x69A0AD8C, 0x160000) + + @staticmethod + fn E16() -> Decimal: + """Returns the value of e^16 as a Decimal.""" + return Decimal.from_words(0xB46A97D, 0x90655BBD, 0x1CB66B18, 0x150000) + + @staticmethod + fn E32() -> Decimal: + """Returns the value of e^32 as a Decimal.""" + return Decimal.from_words(0x18420EB, 0xCC2501E6, 0xFF24A138, 0xF0000) + + @staticmethod + fn E05() -> Decimal: + """Returns the value of e^0.5 = e^(1/2) as a Decimal.""" + return Decimal.from_words(0x8E99DD66, 0xC210E35C, 0x3545E717, 0x1C0000) + + @staticmethod + fn E025() -> Decimal: + """Returns the value of e^0.25 = e^(1/4) as a Decimal.""" + return Decimal.from_words(0xB43646F1, 0x2654858A, 0x297D3595, 0x1C0000) + + @staticmethod + fn LN10() -> Decimal: + """ + Returns the natural logarithm of 10 as a Decimal. + + Returns: + A Decimal representation of ln(10) with maximum precision. + """ + return Decimal.from_words(0x9FA69733, 0x1414B220, 0x4A668998, 0x1C0000) + # ===------------------------------------------------------------------=== # # Constructors and life time dunder methods # ===------------------------------------------------------------------=== # @@ -816,6 +943,23 @@ struct Decimal( """ return 'Decimal("' + self.__str__() + '")' + fn repr_from_words(self) -> String: + """ + Returns a string representation of the Decimal's internal words. + Decimal.from_words(low, mid, high, flags). + """ + return ( + "Decimal.from_words(" + + hex(self.low) + + ", " + + hex(self.mid) + + ", " + + hex(self.high) + + ", " + + hex(self.flags) + + ")" + ) + fn to_int128(self) -> Int128: """ Returns the signed integral part of the Decimal. @@ -871,7 +1015,7 @@ struct Decimal( The absolute value of this Decimal. """ - return decimojo.maths.absolute(self) + return decimojo.arithmetics.absolute(self) fn __neg__(self) -> Self: """ @@ -881,7 +1025,7 @@ struct Decimal( The negation of this Decimal. """ - return decimojo.maths.negative(self) + return decimojo.arithmetics.negative(self) # ===------------------------------------------------------------------=== # # Basic binary arithmetic operation dunders @@ -902,21 +1046,21 @@ struct Decimal( """ try: - return decimojo.maths.add(self, other) + return decimojo.arithmetics.add(self, other) except e: - raise Error("Error in `__add__()`; ", e) + raise Error("Error in `__add__()`: ", e) fn __add__(self, other: Float64) raises -> Self: - return decimojo.maths.add(self, Decimal(other)) + return decimojo.arithmetics.add(self, Decimal(other)) fn __add__(self, other: Int) raises -> Self: - return decimojo.maths.add(self, Decimal(other)) + return decimojo.arithmetics.add(self, Decimal(other)) fn __radd__(self, other: Float64) raises -> Self: - return decimojo.maths.add(Decimal(other), self) + return decimojo.arithmetics.add(Decimal(other), self) fn __radd__(self, other: Int) raises -> Self: - return decimojo.maths.add(Decimal(other), self) + return decimojo.arithmetics.add(Decimal(other), self) fn __sub__(self, other: Decimal) raises -> Self: """ @@ -941,52 +1085,52 @@ struct Decimal( """ try: - return decimojo.maths.subtract(self, other) + return decimojo.arithmetics.subtract(self, other) except e: - raise Error("Error in `__sub__()`; ", e) + raise Error("Error in `__sub__()`: ", e) fn __sub__(self, other: Float64) raises -> Self: - return decimojo.maths.subtract(self, Decimal(other)) + return decimojo.arithmetics.subtract(self, Decimal(other)) fn __sub__(self, other: Int) raises -> Self: - return decimojo.maths.subtract(self, Decimal(other)) + return decimojo.arithmetics.subtract(self, Decimal(other)) fn __rsub__(self, other: Float64) raises -> Self: - return decimojo.maths.subtract(Decimal(other), self) + return decimojo.arithmetics.subtract(Decimal(other), self) fn __rsub__(self, other: Int) raises -> Self: - return decimojo.maths.subtract(Decimal(other), self) + return decimojo.arithmetics.subtract(Decimal(other), self) fn __mul__(self, other: Decimal) raises -> Self: """ Multiplies two Decimal values and returns a new Decimal containing the product. """ - return decimojo.maths.multiply(self, other) + return decimojo.arithmetics.multiply(self, other) fn __mul__(self, other: Float64) raises -> Self: - return decimojo.maths.multiply(self, Decimal(other)) + return decimojo.arithmetics.multiply(self, Decimal(other)) fn __mul__(self, other: Int) raises -> Self: - return decimojo.maths.multiply(self, Decimal(other)) + return decimojo.arithmetics.multiply(self, Decimal(other)) fn __truediv__(self, other: Decimal) raises -> Self: """ Divides this Decimal by another Decimal and returns a new Decimal containing the result. """ - return decimojo.maths.true_divide(self, other) + return decimojo.arithmetics.true_divide(self, other) fn __truediv__(self, other: Float64) raises -> Self: - return decimojo.maths.true_divide(self, Decimal(other)) + return decimojo.arithmetics.true_divide(self, Decimal(other)) fn __truediv__(self, other: Int) raises -> Self: - return decimojo.maths.true_divide(self, Decimal(other)) + return decimojo.arithmetics.true_divide(self, Decimal(other)) fn __rtruediv__(self, other: Float64) raises -> Self: - return decimojo.maths.true_divide(Decimal(other), self) + return decimojo.arithmetics.true_divide(Decimal(other), self) fn __rtruediv__(self, other: Int) raises -> Self: - return decimojo.maths.true_divide(Decimal(other), self) + return decimojo.arithmetics.true_divide(Decimal(other), self) fn __pow__(self, exponent: Decimal) raises -> Self: """ @@ -1106,7 +1250,7 @@ struct Decimal( """ try: - return decimojo.maths.round( + return decimojo.rounding.round( self, ndigits=ndigits, rounding_mode=RoundingMode.HALF_EVEN() ) except e: @@ -1119,9 +1263,22 @@ struct Decimal( # ===------------------------------------------------------------------=== # # Mathematical methods that do not implement a trait (not a dunder) - # round, sqrt + # exp, round, sqrt # ===------------------------------------------------------------------=== # + fn exp(self) raises -> Self: + """ + Calculates the exponential of this Decimal. + + Returns: + The exponential of this Decimal. + """ + + try: + return decimojo.exponential.exp(self) + except e: + raise Error("Error in `Decimal.exp()`: ", e) + fn round( self, ndigits: Int = 0, @@ -1147,11 +1304,11 @@ struct Decimal( """ try: - return decimojo.maths.round( + return decimojo.rounding.round( self, ndigits=ndigits, rounding_mode=rounding_mode ) except e: - raise Error("Error in `Decimal.round()`; ", e) + raise Error("Error in `Decimal.round()`: ", e) fn sqrt(self) raises -> Self: """ @@ -1164,7 +1321,7 @@ struct Decimal( Error: If the operation would result in overflow. """ - return decimojo.maths.sqrt(self) + return decimojo.exponential.sqrt(self) # ===------------------------------------------------------------------=== # # Other methods @@ -1305,8 +1462,8 @@ struct Decimal( - Zero if |self| = |other| - Negative value if |self| < |other| """ - var abs_self = decimojo.maths.absolute(self) - var abs_other = decimojo.maths.absolute(other) + var abs_self = decimojo.arithmetics.absolute(self) + var abs_other = decimojo.arithmetics.absolute(other) if abs_self > abs_other: return 1 diff --git a/src/decimojo/maths/exp.mojo b/src/decimojo/exponential.mojo similarity index 51% rename from src/decimojo/maths/exp.mojo rename to src/decimojo/exponential.mojo index 6a8193b..b17a635 100644 --- a/src/decimojo/maths/exp.mojo +++ b/src/decimojo/exponential.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # # # Implements exponential functions for the Decimal type @@ -19,6 +34,7 @@ import math as builtin_math import testing +import decimojo.special import decimojo.utility @@ -47,18 +63,33 @@ fn power(base: Decimal, exponent: Decimal) raises -> Decimal: # Convert exponent to integer var exp_value = Int(exponent) + return power(base, exp_value) + + +fn power(base: Decimal, exponent: Int) raises -> Decimal: + """ + Convenience method to raise base to an integer power. + + Args: + base: The base value. + exponent: The integer power to raise base to. + + Returns: + A new Decimal containing the result. + """ + # Special cases - if exp_value == 0: + if exponent == 0: # x^0 = 1 (including 0^0 = 1 by convention) return Decimal.ONE() - if exp_value == 1: + if exponent == 1: # x^1 = x return base if base.is_zero(): # 0^n = 0 for n > 0 - if exp_value > 0: + if exponent > 0: return Decimal.ZERO() else: # 0^n is undefined for n < 0 @@ -69,21 +100,22 @@ fn power(base: Decimal, exponent: Decimal) raises -> Decimal: return Decimal.ONE() # Handle negative exponents: x^(-n) = 1/(x^n) - var negative_exponent = exp_value < 0 + var negative_exponent = exponent < 0 + var abs_exp = exponent if negative_exponent: - exp_value = -exp_value + abs_exp = -exponent # Binary exponentiation for efficiency var result = Decimal.ONE() var current_base = base - while exp_value > 0: - if exp_value & 1: # exp_value is odd + while abs_exp > 0: + if abs_exp & 1: # exp_value is odd result = result * current_base - exp_value >>= 1 # exp_value = exp_value / 2 + abs_exp >>= 1 # exp_value = exp_value / 2 - if exp_value > 0: + if abs_exp > 0: current_base = current_base * current_base # For negative exponents, take the reciprocal @@ -94,20 +126,6 @@ fn power(base: Decimal, exponent: Decimal) raises -> Decimal: return result -fn power(base: Decimal, exponent: Int) raises -> Decimal: - """ - Convenience method to raise base to an integer power. - - Args: - base: The base value. - exponent: The integer power to raise base to. - - Returns: - A new Decimal containing the result. - """ - return power(base, Decimal(exponent)) - - fn sqrt(x: Decimal) raises -> Decimal: """ Computes the square root of a Decimal value using Newton-Raphson method. @@ -229,3 +247,199 @@ fn sqrt(x: Decimal) raises -> Decimal: ) return guess + + +fn exp(x: Decimal) raises -> Decimal: + """ + Calculates e^x for any Decimal value using optimized range reduction. + x should be no greater than 66 to avoid overflow. + + Args: + x: The exponent. + + Returns: + A Decimal approximation of e^x. + + Notes: + Because ln(2^96-1) ~= 66.54212933375474970405428366, + the x value should be no greater than 66 to avoid overflow. + """ + + # Handle special cases + if x.is_zero(): + return Decimal.ONE() + + if x.is_negative(): + return Decimal.ONE() / exp(-x) + + # For x < 1, use Taylor series expansion + # For x > 1, use optimized range reduction with smaller chunks + # Yuhao's notes: + # e^50 is more accurate than (e^2)^25 if e^2 needs to be approximated + # because estimating e^x would introduce errors + # e^50 is less accurate than (e^2)^25 if e^2 is precomputed + # because too many multiplications would introduce errors + # So we need to find a way to reduce both the number of multiplications + # and the error introduced by approximating e^x + # This helps improve accuracy as well as speed. + # My solution is to factorize x into a combination of integers and + # a fractional part smaller than 1. + # Then use precomputed e^integer values to calculate e^x + # For example, e^59.12 = (e^50)^1 * (e^5)^1 * (e^2)^2 * e^0.12 + # This way, we just need to do 4 multiplications instead of 59. + # The fractional part is then calculated using the series expansion. + # Because the fractional part is <1, the series converges quickly. + + var exp_chunk: Decimal + var remainder: Decimal + var num_chunks: Int = 1 + var x_int = Int(x) + + if x.is_one(): + return Decimal.E() + + elif x_int < 1: + var d05 = Decimal(5, 0, 0, scale=1, sign=False) # 0.5 + var d025 = Decimal(25, 0, 0, scale=2, sign=False) # 0.25 + + if x < d025: # 0 < x < 0.25 + return exp_series(x) + + elif x < d05: # 0.25 <= x < 0.5 + exp_chunk = Decimal.E025() + remainder = x - d025 + + else: # 0.5 <= x < 1 + exp_chunk = Decimal.E05() + remainder = x - d05 + + elif x_int == 1: # 1 <= x < 2, chunk = 1 + exp_chunk = Decimal.E() + remainder = x - x_int + + elif x_int == 2: # 2 <= x < 3, chunk = 2 + exp_chunk = Decimal.E2() + remainder = x - x_int + + elif x_int == 3: # 3 <= x < 4, chunk = 3 + exp_chunk = Decimal.E3() + remainder = x - x_int + + elif x_int == 4: # 4 <= x < 5, chunk = 4 + exp_chunk = Decimal.E4() + remainder = x - x_int + + elif x_int == 5: # 5 <= x < 6, chunk = 5 + exp_chunk = Decimal.E5() + remainder = x - x_int + + elif x_int == 6: # 6 <= x < 7, chunk = 6 + exp_chunk = Decimal.E6() + remainder = x - x_int + + elif x_int == 7: # 7 <= x < 8, chunk = 7 + exp_chunk = Decimal.E7() + remainder = x - x_int + + elif x_int == 8: # 8 <= x < 9, chunk = 8 + exp_chunk = Decimal.E8() + remainder = x - x_int + + elif x_int == 9: # 9 <= x < 10, chunk = 9 + exp_chunk = Decimal.E9() + remainder = x - x_int + + elif x_int == 10: # 10 <= x < 11, chunk = 10 + exp_chunk = Decimal.E10() + remainder = x - x_int + + elif x_int == 11: # 11 <= x < 12, chunk = 11 + exp_chunk = Decimal.E11() + remainder = x - x_int + + elif x_int == 12: # 12 <= x < 13, chunk = 12 + exp_chunk = Decimal.E12() + remainder = x - x_int + + elif x_int == 13: # 13 <= x < 14, chunk = 13 + exp_chunk = Decimal.E13() + remainder = x - x_int + + elif x_int == 14: # 14 <= x < 15, chunk = 14 + exp_chunk = Decimal.E14() + remainder = x - x_int + + elif x_int == 15: # 15 <= x < 16, chunk = 15 + exp_chunk = Decimal.E15() + remainder = x - x_int + + elif x_int < 32: # 16 <= x < 32, chunk = 16 + num_chunks = x_int >> 4 + exp_chunk = Decimal.E16() + remainder = x - (num_chunks << 4) + + else: # chunk = 32 + num_chunks = x_int >> 5 + exp_chunk = Decimal.E32() + remainder = x - (num_chunks << 5) + + # Calculate e^(chunk * num_chunks) = (e^chunk)^num_chunks + var exp_main = power(exp_chunk, num_chunks) + + # Calculate e^remainder by calling exp() again + # If it is <1, then use Taylor's series + var exp_remainder = exp(remainder) + + # Combine: e^x = e^(main+remainder) = e^main * e^remainder + return exp_main * exp_remainder + + +fn exp_series(x: Decimal) raises -> Decimal: + """ + Calculates e^x using Taylor series expansion. + Do not use this function for values larger than 1, but `exp()` instead. + + Args: + x: The exponent. + + Returns: + A Decimal approximation of e^x. + + Notes: + + Sum terms of Taylor series: e^x = 1 + x + x²/2! + x³/3! + ... + Because ln(2^96-1) ~= 66.54212933375474970405428366, + the x value should be no greater than 66 to avoid overflow. + """ + + var max_terms = 500 + + # For x=0, e^0 = 1 + if x.is_zero(): + return Decimal.ONE() + + # For x with very small magnitude, just use 1+x approximation + if abs(x) == Decimal("1e-28"): + return Decimal.ONE() + x + + # Initialize result and term + var result = Decimal.ONE() + var term = Decimal.ONE() + var term_add_on: Decimal + + # Calculate terms iteratively + # term[x] = x^i / i! + # term[x-1] = x^{i-1} / (i-1)! + # => term[x] / term[x-1] = x / i + + for i in range(1, max_terms + 1): + term_add_on = x / Decimal(i) + + term = term * term_add_on + # Check for convergence + if term.is_zero(): + break + + result = result + term + + return result diff --git a/src/decimojo/maths/__init__.mojo b/src/decimojo/maths/__init__.mojo deleted file mode 100644 index f44a61a..0000000 --- a/src/decimojo/maths/__init__.mojo +++ /dev/null @@ -1,37 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE -# ===----------------------------------------------------------------------=== # -# -# Implements basic object methods for the Decimal type -# which supports correctly-rounded, fixed-point arithmetic. -# -# ===----------------------------------------------------------------------=== # -# -# TODO Additional functions planned for future implementation: -# -# root(x: Decimal, n: Int): Computes the nth root of x using Newton's method -# exp(x: Decimal): Computes e raised to the power of x -# ln(x: Decimal): Computes the natural logarithm of x -# log10(x: Decimal): Computes the base-10 logarithm of x -# sin(x: Decimal): Computes the sine of x (in radians) -# cos(x: Decimal): Computes the cosine of x (in radians) -# tan(x: Decimal): Computes the tangent of x (in radians) -# abs(x: Decimal): Returns the absolute value of x -# floor(x: Decimal): Returns the largest integer <= x -# ceil(x: Decimal): Returns the smallest integer >= x -# gcd(a: Decimal, b: Decimal): Returns greatest common divisor of a and b -# lcm(a: Decimal, b: Decimal): Returns least common multiple of a and b -# ===----------------------------------------------------------------------=== # - -from .arithmetics import ( - add, - subtract, - negative, - absolute, - multiply, - true_divide, -) -from .exp import power, sqrt -from .rounding import round diff --git a/src/decimojo/prelude.mojo b/src/decimojo/prelude.mojo index b4f070a..c1e992f 100644 --- a/src/decimojo/prelude.mojo +++ b/src/decimojo/prelude.mojo @@ -1,3 +1,24 @@ +# ===----------------------------------------------------------------------=== # +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===----------------------------------------------------------------------=== # + """ Provides a list of things that can be imported at one time. The list contains the functions or types that are the most essential for a user. diff --git a/src/decimojo/maths/rounding.mojo b/src/decimojo/rounding.mojo similarity index 89% rename from src/decimojo/maths/rounding.mojo rename to src/decimojo/rounding.mojo index 6920b51..b9960aa 100644 --- a/src/decimojo/maths/rounding.mojo +++ b/src/decimojo/rounding.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # # # Implements basic object methods for the Decimal type diff --git a/src/decimojo/rounding_mode.mojo b/src/decimojo/rounding_mode.mojo index 6bf1968..d1cb629 100644 --- a/src/decimojo/rounding_mode.mojo +++ b/src/decimojo/rounding_mode.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # diff --git a/src/decimojo/special.mojo b/src/decimojo/special.mojo new file mode 100644 index 0000000..3730398 --- /dev/null +++ b/src/decimojo/special.mojo @@ -0,0 +1,281 @@ +# ===----------------------------------------------------------------------=== # +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===----------------------------------------------------------------------=== # +# +# Implements special functions for the Decimal type +# +# ===----------------------------------------------------------------------=== # + +"""Implements functions for special operations on Decimal objects.""" + + +fn factorial(n: Int) raises -> Decimal: + """Calculates the factorial of a non-negative integer. + + Args: + n: The non-negative integer to calculate the factorial of. + + Returns: + The factorial of n. + + Notes: + + 27! is the largest factorial that can be represented by Decimal. + An error will be raised if n is greater than 27. + """ + + if n < 0: + raise Error("Factorial is not defined for negative numbers") + + if n > 27: + raise Error("{}! is too large to be represented by Decimal".format(n)) + + # Directly return the factorial for n = 0 to 27 + if n == 0 or n == 1: + return Decimal.from_words(1, 0, 0, 0) # 1 + elif n == 2: + return Decimal.from_words(2, 0, 0, 0) # 2 + elif n == 3: + return Decimal.from_words(6, 0, 0, 0) # 6 + elif n == 4: + return Decimal.from_words(24, 0, 0, 0) # 24 + elif n == 5: + return Decimal.from_words(120, 0, 0, 0) # 120 + elif n == 6: + return Decimal.from_words(720, 0, 0, 0) # 720 + elif n == 7: + return Decimal.from_words(5040, 0, 0, 0) # 5040 + elif n == 8: + return Decimal.from_words(40320, 0, 0, 0) # 40320 + elif n == 9: + return Decimal.from_words(362880, 0, 0, 0) # 362880 + elif n == 10: + return Decimal.from_words(3628800, 0, 0, 0) # 3628800 + elif n == 11: + return Decimal.from_words(39916800, 0, 0, 0) # 39916800 + elif n == 12: + return Decimal.from_words(479001600, 0, 0, 0) # 479001600 + elif n == 13: + return Decimal.from_words(1932053504, 1, 0, 0) # 6227020800 + elif n == 14: + return Decimal.from_words(1278945280, 20, 0, 0) # 87178291200 + elif n == 15: + return Decimal.from_words(2004310016, 304, 0, 0) # 1307674368000 + elif n == 16: + return Decimal.from_words(2004189184, 4871, 0, 0) # 20922789888000 + elif n == 17: + return Decimal.from_words(4006445056, 82814, 0, 0) # 355687428096000 + elif n == 18: + return Decimal.from_words(3396534272, 1490668, 0, 0) # 6402373705728000 + elif n == 19: + return Decimal.from_words( + 109641728, 28322707, 0, 0 + ) # 121645100408832000 + elif n == 20: + return Decimal.from_words( + 2192834560, 566454140, 0, 0 + ) # 2432902008176640000 + elif n == 21: + return Decimal.from_words( + 3099852800, 3305602358, 2, 0 + ) # 51090942171709440000 + elif n == 22: + return Decimal.from_words( + 3772252160, 4003775155, 60, 0 + ) # 1124000727777607680000 + elif n == 23: + return Decimal.from_words( + 862453760, 1892515369, 1401, 0 + ) # 25852016738884976640000 + elif n == 24: + return Decimal.from_words( + 3519021056, 2470695900, 33634, 0 + ) # 620448401733239439360000 + elif n == 25: + return Decimal.from_words( + 2076180480, 1637855376, 840864, 0 + ) # 15511210043330985984000000 + elif n == 26: + return Decimal.from_words( + 2441084928, 3929534124, 21862473, 0 + ) # 403291461126605650322784000 + else: + return Decimal.from_words( + 1484783616, 3018206259, 590286795, 0 + ) # 10888869450418352160768000000 + + +fn factorial_reciprocal(n: Int) raises -> Decimal: + """Calculates the reciprocal of factorial of a non-negative integer (1/n!). + + Args: + n: The non-negative integer to calculate the reciprocal factorial of. + + Returns: + The reciprocal of factorial of n (1/n!). + + Notes: + This function is optimized for Taylor series calculations. + The function uses pre-computed values for speed. + For n > 27, the result is effectively 0 at Decimal precision. + """ + + # 1/0! = 1, Decimal.from_words(0x1, 0x0, 0x0, 0x0) + # 1/1! = 1, Decimal.from_words(0x1, 0x0, 0x0, 0x0) + # 1/2! = 0.5, Decimal.from_words(0x5, 0x0, 0x0, 0x10000) + # 1/3! = 0.1666666666666666666666666667, Decimal.from_words(0x82aaaaab, 0xa5b8065, 0x562a265, 0x1c0000) + # 1/4! = 0.0416666666666666666666666667, Decimal.from_words(0x60aaaaab, 0x4296e019, 0x158a899, 0x1c0000) + # 1/5! = 0.0083333333333333333333333333, Decimal.from_words(0x13555555, 0xd516005, 0x44ee85, 0x1c0000) + # 1/6! = 0.0013888888888888888888888889, Decimal.from_words(0x2de38e39, 0x2ce2e556, 0xb7d16, 0x1c0000) + # 1/7! = 0.0001984126984126984126984127, Decimal.from_words(0xe1fbefbf, 0xbd44fc30, 0x1a427, 0x1c0000) + # 1/8! = 0.0000248015873015873015873016, Decimal.from_words(0x1c3f7df8, 0xf7a89f86, 0x3484, 0x1c0000) + # 1/9! = 0.0000027557319223985890652557, Decimal.from_words(0xca3ff18d, 0xe2a0f547, 0x5d5, 0x1c0000) + # 1/10! = 0.0000002755731922398589065256, Decimal.from_words(0x94399828, 0x63767eed, 0x95, 0x1c0000) + # 1/11! = 0.0000000250521083854417187751, Decimal.from_words(0xb06253a7, 0x94adae72, 0xd, 0x1c0000) + # 1/12! = 0.0000000020876756987868098979, Decimal.from_words(0xe40831a3, 0x21b923de, 0x1, 0x1c0000) + # 1/13! = 0.0000000001605904383682161460, Decimal.from_words(0x4c9e2b34, 0x16495187, 0x0, 0x1c0000) + # 1/14! = 0.0000000000114707455977297247, Decimal.from_words(0xce9d955f, 0x19785d2, 0x0, 0x1c0000) + # 1/15! = 0.0000000000007647163731819816, Decimal.from_words(0xdc63d28, 0x1b2b0e, 0x0, 0x1c0000) + # 1/16! = 0.0000000000000477947733238739, Decimal.from_words(0xe0dc63d3, 0x1b2b0, 0x0, 0x1c0000) + # 1/17! = 0.0000000000000028114572543455, Decimal.from_words(0xef1c05df, 0x1991, 0x0, 0x1c0000) + # 1/18! = 0.0000000000000001561920696859, Decimal.from_words(0xa9ba721b, 0x16b, 0x0, 0x1c0000) + # 1/19! = 0.0000000000000000082206352466, Decimal.from_words(0x23e16452, 0x13, 0x0, 0x1c0000) + # 1/20! = 0.0000000000000000004110317623, Decimal.from_words(0xf4fe7837, 0x0, 0x0, 0x1c0000) + # 1/21! = 0.0000000000000000000195729411, Decimal.from_words(0xbaa9803, 0x0, 0x0, 0x1c0000) + # 1/22! = 0.0000000000000000000008896791, Decimal.from_words(0x87c117, 0x0, 0x0, 0x1c0000) + # 1/23! = 0.0000000000000000000000386817, Decimal.from_words(0x5e701, 0x0, 0x0, 0x1c0000) + # 1/24! = 0.0000000000000000000000016117, Decimal.from_words(0x3ef5, 0x0, 0x0, 0x1c0000) + # 1/25! = 0.0000000000000000000000000645, Decimal.from_words(0x285, 0x0, 0x0, 0x1c0000) + # 1/26! = 0.0000000000000000000000000025, Decimal.from_words(0x19, 0x0, 0x0, 0x1c0000) + # 1/27! = 0.0000000000000000000000000001, Decimal.from_words(0x1, 0x0, 0x0, 0x1c0000) + + if n < 0: + raise Error("Factorial reciprocal is not defined for negative numbers") + + # For n > 27, 1/n! is essentially 0 at Decimal precision + # Return 0 with max scale + if n > 27: + return Decimal.from_words(0, 0, 0, 0x001C0000) + + # Directly return pre-computed reciprocal factorials + if n == 0 or n == 1: + return Decimal.from_words(0x1, 0x0, 0x0, 0x0) # 1 + elif n == 2: + return Decimal.from_words(0x5, 0x0, 0x0, 0x10000) # 0.5 + elif n == 3: + return Decimal.from_words( + 0x82AAAAAB, 0xA5B8065, 0x562A265, 0x1C0000 + ) # 0.1666... + elif n == 4: + return Decimal.from_words( + 0x60AAAAAB, 0x4296E019, 0x158A899, 0x1C0000 + ) # 0.0416... + elif n == 5: + return Decimal.from_words( + 0x13555555, 0xD516005, 0x44EE85, 0x1C0000 + ) # 0.0083... + elif n == 6: + return Decimal.from_words( + 0x2DE38E39, 0x2CE2E556, 0xB7D16, 0x1C0000 + ) # 0.0013... + elif n == 7: + return Decimal.from_words( + 0xE1FBEFBF, 0xBD44FC30, 0x1A427, 0x1C0000 + ) # 0.0001... + elif n == 8: + return Decimal.from_words( + 0x1C3F7DF8, 0xF7A89F86, 0x3484, 0x1C0000 + ) # 0.0000248... + elif n == 9: + return Decimal.from_words( + 0xCA3FF18D, 0xE2A0F547, 0x5D5, 0x1C0000 + ) # 0.0000027... + elif n == 10: + return Decimal.from_words( + 0x94399828, 0x63767EED, 0x95, 0x1C0000 + ) # 0.00000027... + elif n == 11: + return Decimal.from_words( + 0xB06253A7, 0x94ADAE72, 0xD, 0x1C0000 + ) # 0.000000025... + elif n == 12: + return Decimal.from_words( + 0xE40831A3, 0x21B923DE, 0x1, 0x1C0000 + ) # 0.0000000020... + elif n == 13: + return Decimal.from_words( + 0x4C9E2B34, 0x16495187, 0x0, 0x1C0000 + ) # 0.0000000001... + elif n == 14: + return Decimal.from_words( + 0xCE9D955F, 0x19785D2, 0x0, 0x1C0000 + ) # 0.00000000001... + elif n == 15: + return Decimal.from_words( + 0xDC63D28, 0x1B2B0E, 0x0, 0x1C0000 + ) # 0.0000000000007... + elif n == 16: + return Decimal.from_words( + 0xE0DC63D3, 0x1B2B0, 0x0, 0x1C0000 + ) # 0.00000000000004... + elif n == 17: + return Decimal.from_words( + 0xEF1C05DF, 0x1991, 0x0, 0x1C0000 + ) # 0.000000000000002... + elif n == 18: + return Decimal.from_words( + 0xA9BA721B, 0x16B, 0x0, 0x1C0000 + ) # 0.0000000000000001... + elif n == 19: + return Decimal.from_words( + 0x23E16452, 0x13, 0x0, 0x1C0000 + ) # 0.0000000000000000082... + elif n == 20: + return Decimal.from_words( + 0xF4FE7837, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000004... + elif n == 21: + return Decimal.from_words( + 0xBAA9803, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000195... + elif n == 22: + return Decimal.from_words( + 0x87C117, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000008... + elif n == 23: + return Decimal.from_words( + 0x5E701, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000000386... + elif n == 24: + return Decimal.from_words( + 0x3EF5, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000000016... + elif n == 25: + return Decimal.from_words( + 0x285, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000000000645 + elif n == 26: + return Decimal.from_words( + 0x19, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000000000025 + else: # n == 27 + return Decimal.from_words( + 0x1, 0x0, 0x0, 0x1C0000 + ) # 0.0000000000000000000000000001 diff --git a/src/decimojo/utility.mojo b/src/decimojo/utility.mojo index 565b97f..95aa6a2 100644 --- a/src/decimojo/utility.mojo +++ b/src/decimojo/utility.mojo @@ -1,7 +1,22 @@ # ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE +# +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forFudan/DeciMojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # ===----------------------------------------------------------------------=== # # # Implements internal utility functions for the Decimal type diff --git a/tests/test_arithmetics.mojo b/tests/test_arithmetics.mojo index ade4bd5..9162f2b 100644 --- a/tests/test_arithmetics.mojo +++ b/tests/test_arithmetics.mojo @@ -159,6 +159,19 @@ fn test_add() raises: print("Decimal addition tests passed!") + # Test case 18: Edge case with one equals 0 + var a18 = Decimal("4563117171088016.3026696499898") + var b18 = Decimal("0.0000000000000000000000000000") + var result18 = a18 + b18 + # 1234567890123456789.0123456789 + testing.assert_equal( + String(result18), + "4563117171088016.3026696499898", + "Addition with zeros", + ) + + print("Decimal addition tests passed!") + fn test_negation() raises: print("------------------------------------------------------") diff --git a/tests/test_exp.mojo b/tests/test_exp.mojo new file mode 100644 index 0000000..6e259fa --- /dev/null +++ b/tests/test_exp.mojo @@ -0,0 +1,360 @@ +""" +Comprehensive tests for the exp() function in the DeciMojo library. +Tests various cases including basic values, mathematical identities, +and edge cases to ensure proper calculation of e^x. +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode +from decimojo.exponential import exp + + +fn test_basic_exp_values() raises: + """Test basic exponential function values.""" + print("Testing basic exponential values...") + + # Test case 1: e^0 = 1 + var zero = Decimal(String("0")) + var result0 = exp(zero) + testing.assert_equal( + String(result0), String("1"), "e^0 should be 1, got " + String(result0) + ) + + # Test case 2: e^1 should be close to Euler's number + var one = Decimal(String("1")) + var result1 = exp(one) + var expected1 = String( + "2.718281828459045235360287471" + ) # e to 27 decimal places + testing.assert_true( + String(result1).startswith(expected1[0:25]), + "e^1 should be approximately " + + String(expected1) + + ", got " + + String(result1), + ) + + # Test case 3: e^2 + var two = Decimal(String("2")) + var result2 = exp(two) + var expected2 = String( + "7.389056098930650227230427461" + ) # e^2 to 27 decimal places + testing.assert_true( + String(result2).startswith(expected2[0:25]), + "e^2 should be approximately " + + String(expected2) + + ", got " + + String(result2), + ) + + # Test case 4: e^3 + var three = Decimal(String("3")) + var result3 = exp(three) + var expected3 = String( + "20.08553692318766774092852965" + ) # e^3 to 27 decimal places + testing.assert_true( + String(result3).startswith(expected3[0:25]), + "e^3 should be approximately " + + String(expected3) + + ", got " + + String(result3), + ) + + # Test case 5: e^5 + var five = Decimal(String("5")) + var result5 = exp(five) + var expected5 = String( + "148.41315910257660342111558004055" + ) # e^5 to 27 decimal places + testing.assert_true( + String(result5).startswith(expected5[0:25]), + "e^5 should be approximately " + + String(expected5) + + ", got " + + String(result5), + ) + + print("✓ Basic exponential values tests passed!") + + +fn test_negative_exponents() raises: + """Test exponential function with negative exponents.""" + print("Testing exponential function with negative exponents...") + + # Test case 1: e^(-1) = 1/e + var neg_one = Decimal(String("-1")) + var result1 = exp(neg_one) + var expected1 = String( + "0.3678794411714423215955237702" + ) # e^-1 to 27 decimal places + testing.assert_true( + String(result1).startswith(expected1[0:25]), + "e^-1 should be approximately " + + String(expected1) + + ", got " + + String(result1), + ) + + # Test case 2: e^(-2) = 1/e^2 + var neg_two = Decimal(String("-2")) + var result2 = exp(neg_two) + var expected2 = String( + "0.1353352832366126918939994950" + ) # e^-2 to 27 decimal places + testing.assert_true( + String(result2).startswith(expected2[0:25]), + "e^-2 should be approximately " + + String(expected2) + + ", got " + + String(result2), + ) + + # Test case 3: e^(-5) + var neg_five = Decimal(String("-5")) + var result3 = exp(neg_five) + var expected3 = String( + "0.006737946999085467096636048777" + ) # e^-5 to 27 decimal places + testing.assert_true( + String(result3).startswith(expected3[0:25]), + "e^-5 should be approximately " + + String(expected3) + + ", got " + + String(result3), + ) + + print("✓ Negative exponents tests passed!") + + +fn test_fractional_exponents() raises: + """Test exponential function with fractional exponents.""" + print("Testing exponential function with fractional exponents...") + + # Test case 1: e^0.5 + var half = Decimal(String("0.5")) + var result1 = exp(half) + var expected1 = String( + "1.648721270700128146848650787" + ) # e^0.5 to 27 decimal places + testing.assert_true( + String(result1).startswith(expected1[0:25]), + "e^0.5 should be approximately " + + String(expected1) + + ", got " + + String(result1), + ) + + # Test case 2: e^0.1 + var tenth = Decimal(String("0.1")) + var result2 = exp(tenth) + var expected2 = String( + "1.105170918075647624811707826" + ) # e^0.1 to 27 decimal places + testing.assert_true( + String(result2).startswith(expected2[0:25]), + "e^0.1 should be approximately " + + String(expected2) + + ", got " + + String(result2), + ) + + # Test case 3: e^(-0.5) + var neg_half = Decimal(String("-0.5")) + var result3 = exp(neg_half) + var expected3 = String( + "0.6065306597126334236037995349" + ) # e^-0.5 to 27 decimal places + testing.assert_true( + String(result3).startswith(expected3[0:25]), + "e^-0.5 should be approximately " + + String(expected3) + + ", got " + + String(result3), + ) + + # Test case 4: e^1.5 + var one_half = Decimal(String("1.5")) + var result4 = exp(one_half) + var expected4 = String( + "4.481689070338064822602055460" + ) # e^1.5 to 27 decimal places + testing.assert_true( + String(result4).startswith(expected4[0:25]), + "e^1.5 should be approximately " + + String(expected4) + + ", got " + + String(result4), + ) + + print("✓ Fractional exponents tests passed!") + + +fn test_high_precision_exponents() raises: + """Test exponential function with high precision inputs.""" + print("Testing exponential function with high precision inputs...") + + # Test case 1: e^π (approximate) + var pi = Decimal(String("3.14159265358979323846264338327950288")) + var result1 = exp(pi) + var expected1 = String( + "23.14069263277926900572908636794" + ) # e^pi to 27 decimal places + testing.assert_true( + String(result1).startswith(expected1[0:25]), + "e^pi should be approximately " + + String(expected1) + + ", got " + + String(result1), + ) + + # Test case 2: e^2.71828 (approximate e) + var approx_e = Decimal(String("2.71828")) + var result2 = exp(approx_e) + var expected2 = String( + "15.154234532556727211057207398340" + ) # e^(~e) to 27 decimal places + testing.assert_true( + String(result2).startswith(expected2[0:25]), + "e^(~e) should be approximately " + + String(expected2) + + ", got " + + String(result2), + ) + + print("✓ High precision exponents tests passed!") + + +fn test_mathematical_identities() raises: + """Test mathematical identities related to the exponential function.""" + print("Testing mathematical identities for exponential function...") + + # Test case 1: e^(a+b) = e^a * e^b + var a = Decimal(String("2")) + var b = Decimal(String("3")) + var exp_a_plus_b = exp(a + b) + var exp_a_times_exp_b = exp(a) * exp(b) + + # Compare with some level of precision to account for computational differences + var diff1 = abs(exp_a_plus_b - exp_a_times_exp_b) + var rel_diff1 = diff1 / exp_a_plus_b + testing.assert_true( + rel_diff1 < Decimal(String("0.0000001")), + "e^(a+b) should equal e^a * e^b within tolerance, difference: " + + String(rel_diff1), + ) + + # Test case 2: e^(-x) = 1/e^x + var x = Decimal(String("1.5")) + var exp_neg_x = exp(-x) + var one_over_exp_x = Decimal(String("1")) / exp(x) + + # Compare with some level of precision + var diff2 = abs(exp_neg_x - one_over_exp_x) + var rel_diff2 = diff2 / exp_neg_x + testing.assert_true( + rel_diff2 < Decimal(String("0.0000001")), + "e^(-x) should equal 1/e^x within tolerance, difference: " + + String(rel_diff2), + ) + + # Test case 3: e^0 = 1 (Already tested in basic values, but included here for completeness) + var zero = Decimal(String("0")) + var exp_zero = exp(zero) + testing.assert_equal(String(exp_zero), String("1"), "e^0 should equal 1") + + print("✓ Mathematical identities tests passed!") + + +fn test_extreme_values() raises: + """Test exponential function with extreme values.""" + print("Testing exponential function with extreme values...") + + # Test case 1: Very small positive input + var small_input = Decimal(String("0.0000001")) + var result1 = exp(small_input) + testing.assert_true( + String(result1).startswith(String("1.0000001")), + "e^0.0000001 should be approximately 1.0000001, got " + String(result1), + ) + + # Test case 2: Very small negative input + var small_neg_input = Decimal(String("-0.0000001")) + var result2 = exp(small_neg_input) + testing.assert_true( + String(result2).startswith(String("0.9999999")), + "e^-0.0000001 should be approximately 0.9999999, got " + + String(result2), + ) + + # Test case 3: Large positive value + # This should not overflow but should produce a very large result + # Note: The implementation may have specific limits + var large_input = Decimal(String("20")) + var result3 = exp(large_input) + testing.assert_true( + result3 > Decimal(String("100000000")), + "e^20 should be a very large number > 100,000,000, got " + + String(result3), + ) + + print("✓ Extreme values tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for exponential function.""" + print("Testing edge cases for exponential function...") + + # Test with very high precision input + var high_precision = Decimal(String("1.23456789012345678901234567")) + var result_high = exp(high_precision) + testing.assert_true( + len(String(result_high)) > 15, + "Exp with high precision input should produce high precision output", + ) + + print("✓ Edge cases tests passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Exponential Function Tests") + print("=========================================") + + run_test_with_error_handling( + test_basic_exp_values, "Basic exponential values test" + ) + run_test_with_error_handling( + test_negative_exponents, "Negative exponents test" + ) + run_test_with_error_handling( + test_fractional_exponents, "Fractional exponents test" + ) + run_test_with_error_handling( + test_high_precision_exponents, "High precision exponents test" + ) + run_test_with_error_handling( + test_mathematical_identities, "Mathematical identities test" + ) + run_test_with_error_handling(test_extreme_values, "Extreme values test") + run_test_with_error_handling(test_edge_cases, "Edge cases test") + + print("All exponential function tests passed!") diff --git a/tests/test_factorial.mojo b/tests/test_factorial.mojo new file mode 100644 index 0000000..55d3d7b --- /dev/null +++ b/tests/test_factorial.mojo @@ -0,0 +1,278 @@ +""" +Comprehensive tests for the `factorial()` and the `factorial_reciprocal()` +functions in the DeciMojo library. +Tests various cases including edge cases and error handling for factorials +in the range 0 to 27, which is the maximum range supported by Decimal. +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode +from decimojo.special import factorial, factorial_reciprocal + + +fn test_basic_factorials() raises: + """Test basic factorial calculations.""" + print("Testing basic factorial calculations...") + + # Test case 1: 0! = 1 + var result0 = factorial(0) + testing.assert_equal( + String(result0), "1", "0! should be 1, got " + String(result0) + ) + + # Test case 2: 1! = 1 + var result1 = factorial(1) + testing.assert_equal( + String(result1), "1", "1! should be 1, got " + String(result1) + ) + + # Test case 3: 2! = 2 + var result2 = factorial(2) + testing.assert_equal( + String(result2), "2", "2! should be 2, got " + String(result2) + ) + + # Test case 4: 3! = 6 + var result3 = factorial(3) + testing.assert_equal( + String(result3), "6", "3! should be 6, got " + String(result3) + ) + + # Test case 5: 4! = 24 + var result4 = factorial(4) + testing.assert_equal( + String(result4), "24", "4! should be 24, got " + String(result4) + ) + + # Test case 6: 5! = 120 + var result5 = factorial(5) + testing.assert_equal( + String(result5), "120", "5! should be 120, got " + String(result5) + ) + + print("✓ Basic factorial tests passed!") + + +fn test_medium_factorials() raises: + """Test medium-sized factorial calculations.""" + print("Testing medium-sized factorial calculations...") + + # Test case 7: 6! = 720 + var result6 = factorial(6) + testing.assert_equal( + String(result6), "720", "6! should be 720, got " + String(result6) + ) + + # Test case 8: 7! = 5040 + var result7 = factorial(7) + testing.assert_equal( + String(result7), "5040", "7! should be 5040, got " + String(result7) + ) + + # Test case 9: 8! = 40320 + var result8 = factorial(8) + testing.assert_equal( + String(result8), "40320", "8! should be 40320, got " + String(result8) + ) + + # Test case 10: 9! = 362880 + var result9 = factorial(9) + testing.assert_equal( + String(result9), + "362880", + "9! should be 362880, got " + String(result9), + ) + + # Test case 11: 10! = 3628800 + var result10 = factorial(10) + testing.assert_equal( + String(result10), + "3628800", + "10! should be 3628800, got " + String(result10), + ) + + print("✓ Medium factorial tests passed!") + + +fn test_large_factorials() raises: + """Test large factorial calculations.""" + print("Testing large factorial calculations...") + + # Test case 12: 12! = 479001600 + var result12 = factorial(12) + testing.assert_equal( + String(result12), + "479001600", + "12! should be 479001600, got " + String(result12), + ) + + # Test case 13: 15! = 1307674368000 + var result15 = factorial(15) + testing.assert_equal( + String(result15), + "1307674368000", + "15! should be 1307674368000, got " + String(result15), + ) + + # Test case 14: 20! = 2432902008176640000 + var result20 = factorial(20) + testing.assert_equal( + String(result20), + "2432902008176640000", + "20! should be 2432902008176640000, got " + String(result20), + ) + + # Test case 15: 25! + var result25 = factorial(25) + var expected25 = "15511210043330985984000000" + testing.assert_equal( + String(result25), + expected25, + "25! should be " + expected25 + ", got " + String(result25), + ) + + # Test maximum supported factorial: 27! + var result27 = factorial(27) + var expected27 = "10888869450418352160768000000" + testing.assert_equal( + String(result27), + expected27, + "27! should be " + expected27 + ", got " + String(result27), + ) + + print("✓ Large factorial tests passed!") + + +fn test_factorial_properties() raises: + """Test mathematical properties of factorials.""" + print("Testing factorial mathematical properties...") + + # Test case: (n+1)! = (n+1) * n! + # Only test up to 26 because 27 is our maximum supported value + for n in range(0, 26): + var n_fact = factorial(n) + var n_plus_1_fact = factorial(n + 1) + var calculated = n_fact * Decimal(String(n + 1)) + testing.assert_equal( + String(n_plus_1_fact), + String(calculated), + "Property (n+1)! = (n+1)*n! failed for n=" + + String(n) + + "n+1=" + + String(n + 1), + ) + + print("✓ Factorial properties tests passed!") + + +fn test_factorial_edge_cases() raises: + """Test edge cases for factorial function.""" + print("Testing factorial edge cases...") + + # Test case: Error for negative input + var exception_caught = False + try: + var _f1 = factorial(-1) + testing.assert_equal( + True, False, "factorial() of negative should raise exception" + ) + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + # Test case: Error for input > 27 + exception_caught = False + try: + var _f28 = factorial(28) + testing.assert_equal( + True, + False, + "factorial(28) should raise exception (exceeds maximum)", + ) + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + print("✓ Factorial edge case tests passed!") + + +fn test_factorial_of_zero() raises: + """Special test for factorial of zero.""" + print("Testing special case: 0!...") + + # Test case: Verify 0! = 1 (mathematical definition) + var result = factorial(0) + testing.assert_equal(String(result), "1", "0! should equal 1") + + print("✓ Special case test for 0! passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn test_factorial_reciprocal() raises: + """Test that factorial_reciprocal equals 1 divided by factorial.""" + print("Testing factorial_reciprocal function...") + + # Test for all values in the supported range (0-27) + var all_equal = True + for i in range(28): + var a = Decimal(1) / factorial(i) + var b = factorial_reciprocal(i) + + var equal = a == b + if not equal: + all_equal = False + print("Mismatch at " + String(i) + ":") + print(" 1/" + String(i) + "! = " + String(a)) + print(" reciprocal = " + String(b)) + + testing.assert_true( + all_equal, + ( + "factorial_reciprocal(n) should equal Decimal(1)/factorial(n) for" + " all n" + ), + ) + + print("✓ Factorial reciprocal tests passed!") + + +fn main() raises: + print("=========================================") + print("Running Factorial Function Tests (0-27)") + print("=========================================") + + run_test_with_error_handling(test_basic_factorials, "Basic factorials test") + run_test_with_error_handling( + test_medium_factorials, "Medium factorials test" + ) + run_test_with_error_handling(test_large_factorials, "Large factorials test") + run_test_with_error_handling( + test_factorial_properties, "Factorial properties test" + ) + run_test_with_error_handling( + test_factorial_edge_cases, "Factorial edge cases test" + ) + run_test_with_error_handling( + test_factorial_of_zero, "Factorial of zero test" + ) + run_test_with_error_handling( + test_factorial_reciprocal, "Factorial reciprocal test" + ) + + print("All factorial function tests passed!")