diff --git a/.gitignore b/.gitignore index fa5c872..6841c79 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ tempCodeRunnerFile.mojo # log files *.log # local files -/test*.mojo \ No newline at end of file +/test*.mojo +local \ No newline at end of file diff --git a/benches/bench_exp.mojo b/benches/bench_exp.mojo index a669eda..148d5f4 100644 --- a/benches/bench_exp.mojo +++ b/benches/bench_exp.mojo @@ -71,7 +71,6 @@ fn run_benchmark( var mojo_decimal = Decimal(input_value) var pydecimal = Python.import_module("decimal") var py_decimal = pydecimal.Decimal(input_value) - var _py_math = Python.import_module("math") # Execute the operations once to verify correctness var mojo_result = dm.exponential.exp(mojo_decimal) diff --git a/benches/bench_ln.mojo b/benches/bench_ln.mojo new file mode 100644 index 0000000..7ce4674 --- /dev/null +++ b/benches/bench_ln.mojo @@ -0,0 +1,369 @@ +""" +Comprehensive benchmarks for Decimal natural logarithm function (ln). +Compares performance against Python's decimal module with 20 diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_ln_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark( + name: String, + input_value: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal ln with Python Decimal ln. + + Args: + name: Name of the benchmark case. + input_value: String representation of value for ln(x). + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Input value: " + input_value, log_file) + + # Set up Mojo and Python values + var mojo_decimal = Decimal(input_value) + var pydecimal = Python.import_module("decimal") + var py_decimal = pydecimal.Decimal(input_value) + + # Execute the operations once to verify correctness + var mojo_result = dm.exponential.ln(mojo_decimal) + var py_result = py_decimal.ln() + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = dm.exponential.ln(mojo_decimal) + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_decimal.ln() + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo ln(): " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python ln(): " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print( + "=== DeciMojo Natural Logarithm Function (ln) Benchmark ===", log_file + ) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 100 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning natural logarithm function benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: ln(1) = 0 + run_benchmark( + "ln(1) = 0", + "1", + iterations, + log_file, + speedup_factors, + ) + + # Case 2: ln(e) ≈ 1 + run_benchmark( + "ln(e) ≈ 1", + "2.718281828459045235360287471", + iterations, + log_file, + speedup_factors, + ) + + # Case 3: ln(2) + run_benchmark( + "ln(2)", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 4: ln(10) + run_benchmark( + "ln(10)", + "10", + iterations, + log_file, + speedup_factors, + ) + + # Case 5: ln(0.5) + run_benchmark( + "ln(0.5)", + "0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 6: ln(5) + run_benchmark( + "ln(5)", + "5", + iterations, + log_file, + speedup_factors, + ) + + # Case 7: ln with small positive value + run_benchmark( + "Small positive value", + "1.0001", + iterations, + log_file, + speedup_factors, + ) + + # Case 8: ln with very small positive value + run_benchmark( + "Very small positive value", + "1.000000001", + iterations, + log_file, + speedup_factors, + ) + + # Case 9: ln with value slightly less than 1 + run_benchmark( + "Value slightly less than 1", + "0.9999", + iterations, + log_file, + speedup_factors, + ) + + # Case 10: ln with value slightly greater than 1 + run_benchmark( + "Value slightly greater than 1", + "1.0001", + iterations, + log_file, + speedup_factors, + ) + + # Case 11: ln with moderate value + run_benchmark( + "Moderate value", + "7.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 12: ln with large value + run_benchmark( + "Large value", + "1000", + iterations, + log_file, + speedup_factors, + ) + + # Case 13: ln with very large value + run_benchmark( + "Very large value", + "1000000000", + iterations, + log_file, + speedup_factors, + ) + + # Case 14: ln with high precision input + run_benchmark( + "High precision input", + "2.718281828459045235360287471", + iterations, + log_file, + speedup_factors, + ) + + # Case 15: ln with fractional value + run_benchmark( + "Fractional value", + "0.25", + iterations, + log_file, + speedup_factors, + ) + + # Case 16: ln with fractional value of many digits + run_benchmark( + "Fractional value with many digits", + "0.12345678901234567890123456789", + iterations, + log_file, + speedup_factors, + ) + + # Case 17: ln with approximate e value + run_benchmark( + "Approximate e value", + "2.718", + iterations, + log_file, + speedup_factors, + ) + + # Case 18: ln with larger value + run_benchmark( + "Larger value", + "150", + iterations, + log_file, + speedup_factors, + ) + + # Case 19: ln with value between 0 and 1 + run_benchmark( + "Value between 0 and 1", + "0.75", + iterations, + log_file, + speedup_factors, + ) + + # Case 20: ln with value close to zero + run_benchmark( + "Value close to zero", + "0.00001", + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print( + "\n=== Natural Logarithm Function Benchmark Summary ===", log_file + ) + log_print("Benchmarked: 20 different ln() cases", log_file) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print("Average speedup: " + String(average_speedup) + "×", log_file) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)), + log_file, + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/mojoproject.toml b/mojoproject.toml index 9f2f1b7..ef3a35a 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -34,7 +34,7 @@ test = "magic run package && magic run mojo test tests && magic run delete_packa t = "clear && magic run test" test_arith = "magic run package && magic run mojo test tests/test_arithmetics.mojo && magic run delete_package" test_multiply = "magic run package && magic run mojo test tests/test_multiply.mojo && magic run delete_package" -test_divide = "magic run package && magic run mojo test tests/test_division.mojo && magic run delete_package" +test_divide = "magic run package && magic run mojo test tests/test_divide.mojo && magic run delete_package" test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && magic run delete_package" test_round = "magic run package && magic run mojo test tests/test_round.mojo && magic run delete_package" test_creation = "magic run package && magic run mojo test tests/test_creation.mojo && magic run delete_package" @@ -44,6 +44,7 @@ test_to_float = "magic run package && magic run mojo test tests/test_to_float.mo test_comparison = "magic run package && magic run mojo test tests/test_comparison.mojo && magic run delete_package" test_factorial = "magic run package && magic run mojo test tests/test_factorial.mojo && magic run delete_package" test_exp = "magic run package && magic run mojo test tests/test_exp.mojo && magic run delete_package" +test_ln = "magic run package && magic run mojo test tests/test_ln.mojo && magic run delete_package" # benches bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" @@ -56,6 +57,7 @@ bench_from_float = "magic run package && cd benches && magic run mojo bench_from bench_from_string = "magic run package && cd benches && magic run mojo bench_from_string.mojo && cd .. && magic run delete_package" bench_comparison = "magic run package && cd benches && magic run mojo bench_comparison.mojo && cd .. && magic run delete_package" bench_exp = "magic run package && cd benches && magic run mojo bench_exp.mojo && cd .. && magic run delete_package" +bench_ln = "magic run package && cd benches && magic run mojo bench_ln.mojo && cd .. && magic run delete_package" # before commit final = "magic run test && magic run bench" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 910ea02..178149e 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -49,7 +49,7 @@ from .comparison import ( not_equal, ) -from .exponential import power, sqrt, exp +from .exponential import power, sqrt, exp, ln from .rounding import round diff --git a/src/decimojo/arithmetics.mojo b/src/decimojo/arithmetics.mojo index 2399ed3..fa2ad39 100644 --- a/src/decimojo/arithmetics.mojo +++ b/src/decimojo/arithmetics.mojo @@ -125,12 +125,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: if summation > Decimal.MAX_AS_UINT128: # 2^96-1 raise Error("Error in `addition()`: Decimal overflow") - # Extract the 32-bit components from the UInt128 sum - var low = UInt32(summation & 0xFFFFFFFF) - var mid = UInt32((summation >> 32) & 0xFFFFFFFF) - var high = UInt32((summation >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, 0, x1.is_negative()) + return Decimal.from_uint128(summation, 0, x1.is_negative()) # Different signs: subtract the smaller from the larger else: @@ -143,12 +138,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: diff = x2_coef - x1_coef is_negative = x2.is_negative() - # Extract the 32-bit components from the UInt128 difference - low = UInt32(diff & 0xFFFFFFFF) - mid = UInt32((diff >> 32) & 0xFFFFFFFF) - high = UInt32((diff >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, 0, is_negative) + return Decimal.from_uint128(diff, 0, is_negative) # CASE: Integer addition with positive scales elif x1.is_integer() and x2.is_integer(): @@ -173,12 +163,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: scale -= 1 summation *= UInt128(10) ** scale - # Extract the 32-bit components from the UInt128 sum - var low = UInt32(summation & 0xFFFFFFFF) - var mid = UInt32((summation >> 32) & 0xFFFFFFFF) - var high = UInt32((summation >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, scale, x1.is_negative()) + return Decimal.from_uint128(summation, scale, x1.is_negative()) # Different signs: subtract the smaller from the larger else: @@ -202,12 +187,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: scale -= 1 diff *= UInt128(10) ** scale - # Extract the 32-bit components from the UInt128 difference - low = UInt32(diff & 0xFFFFFFFF) - mid = UInt32((diff >> 32) & 0xFFFFFFFF) - high = UInt32((diff >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, scale, is_negative) + return Decimal.from_uint128(diff, scale, is_negative) # CASE: Float addition with the same scale elif x1_scale == x2_scale: @@ -273,7 +253,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: summation = x1_coef_scaled - x2_coef_scaled is_negative = x1.is_negative() elif x1_coef_scaled < x2_coef_scaled: - summation = x2_coef_scaled * x1_coef_scaled + summation = x2_coef_scaled - x1_coef_scaled is_negative = x2.is_negative() else: return Decimal.from_uint128(UInt128(0), x1_scale, False) @@ -478,15 +458,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: prod, num_digits_to_keep ) var final_scale = min(Decimal.MAX_SCALE, combined_scale) - var low = UInt32(truncated_prod & 0xFFFFFFFF) - var mid = UInt32((truncated_prod >> 32) & 0xFFFFFFFF) - var high = UInt32((truncated_prod >> 64) & 0xFFFFFFFF) - return Decimal( - low, - mid, - high, - final_scale, - is_negative, + return Decimal.from_uint128( + truncated_prod, final_scale, is_negative ) # SPECIAL CASE: Second operand has coefficient of 1 @@ -509,15 +482,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: prod, num_digits_to_keep ) var final_scale = min(Decimal.MAX_SCALE, combined_scale) - var low = UInt32(truncated_prod & 0xFFFFFFFF) - var mid = UInt32((truncated_prod >> 32) & 0xFFFFFFFF) - var high = UInt32((truncated_prod >> 64) & 0xFFFFFFFF) - return Decimal( - low, - mid, - high, - final_scale, - is_negative, + return Decimal.from_uint128( + truncated_prod, final_scale, is_negative ) # Determine the number of bits in the coefficients @@ -529,6 +495,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # SPECIAL CASE: Both operands are true integers if x1_scale == 0 and x2_scale == 0: + print("DEBUG: Both operands are true integers") + print("DEBUG: combined_num_bits: ", combined_num_bits) # Small integers, use UInt64 multiplication if combined_num_bits <= 64: var prod: UInt64 = UInt64(x1_coef) * UInt64(x2_coef) @@ -539,21 +507,21 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Moderate integers, use UInt128 multiplication elif combined_num_bits <= 128: var prod: UInt128 = UInt128(x1_coef) * UInt128(x2_coef) - var low = UInt32(prod & 0xFFFFFFFF) - var mid = UInt32((prod >> 32) & 0xFFFFFFFF) - var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, 0, is_negative) + if prod > Decimal.MAX_AS_UINT128: + raise Error( + "Error in `multiply()`: The product is {}, which exceeds" + " the capacity of Decimal (2^96-1)".format(prod) + ) + else: + return Decimal.from_uint128(prod, 0, is_negative) - # Large integers, use UInt256 multiplication + # Large integers, it will definitely overflow else: var prod: UInt256 = UInt256(x1_coef) * UInt256(x2_coef) - if prod > Decimal.MAX_AS_UINT256: - raise Error("Error in `prodtiply()`: Decimal overflow") - else: - var low = UInt32(prod & 0xFFFFFFFF) - var mid = UInt32((prod >> 32) & 0xFFFFFFFF) - var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, 0, is_negative) + raise Error( + "Error in `multiply()`: The product is {}, which exceeds the" + " capacity of Decimal (2^96-1)".format(prod) + ) # SPECIAL CASE: Both operands are integers but with scales # Examples: 123.0 * 456.00 @@ -599,10 +567,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Combined scale more than max precision, no need to truncate if combined_scale <= Decimal.MAX_SCALE: - var low = UInt32(prod & 0xFFFFFFFF) - var mid = UInt32((prod >> 32) & 0xFFFFFFFF) - var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, combined_scale, is_negative) + return Decimal.from_uint128(prod, combined_scale, is_negative) # Combined scale no more than max precision, truncate with rounding else: @@ -622,11 +587,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: ) final_scale = Decimal.MAX_SCALE - var low = UInt32(prod & 0xFFFFFFFF) - var mid = UInt32((prod >> 32) & 0xFFFFFFFF) - var high = UInt32((prod >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, final_scale, is_negative) + return Decimal.from_uint128(prod, final_scale, is_negative) # SUB-CASE: Both operands are moderate # The bits of the product will not exceed 128 bits @@ -679,12 +640,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: ) final_scale = Decimal.MAX_SCALE - # Extract the 32-bit components from the UInt128 product - var low = UInt32(prod & 0xFFFFFFFF) - var mid = UInt32((prod >> 32) & 0xFFFFFFFF) - var high = UInt32((prod >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, final_scale, is_negative) + return Decimal.from_uint128(prod, final_scale, is_negative) # REMAINING CASES: Both operands are big # The bits of the product will not exceed 192 bits @@ -828,11 +784,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: < Decimal.MAX_NUM_DIGITS ): var quot = x1_coef * UInt128(10) ** (-diff_scale) - # print("DEBUG: quot", quot) - var low = UInt32(quot & 0xFFFFFFFF) - var mid = UInt32((quot >> 32) & 0xFFFFFFFF) - var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, 0, is_negative) + return Decimal.from_uint128(quot, 0, is_negative) # If the result should be stored in UInt256 else: @@ -866,10 +818,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # Since -diff_scale is less than 28, the result would not overflow else: var quot = UInt128(1) * UInt128(10) ** (-diff_scale) - var low = UInt32(quot & 0xFFFFFFFF) - var mid = UInt32((quot >> 32) & 0xFFFFFFFF) - var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, 0, is_negative) + return Decimal.from_uint128(quot, 0, is_negative) # SPECIAL CASE: Modulus of coefficients is zero (exact division) # 特例: 係數的餘數爲零 (可除盡) @@ -885,10 +834,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # High will be zero because the quotient is less than 2^48 # For safety, we still calcuate the high word var quot = x1_coef // x2_coef - var low = UInt32(quot & 0xFFFFFFFF) - var mid = UInt32((quot >> 32) & 0xFFFFFFFF) - var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, diff_scale, is_negative) + return Decimal.from_uint128(quot, diff_scale, is_negative) else: # If diff_scale < 0, return the quotient with scaling up @@ -902,10 +848,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: < Decimal.MAX_NUM_DIGITS ): var quot = quot * UInt128(10) ** (-diff_scale) - var low = UInt32(quot & 0xFFFFFFFF) - var mid = UInt32((quot >> 32) & 0xFFFFFFFF) - var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, 0, is_negative) + return Decimal.from_uint128(quot, 0, is_negative) # If the result should be stored in UInt256 else: @@ -1052,11 +995,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: ) scale_of_quot = Decimal.MAX_SCALE - var low = UInt32(quot & 0xFFFFFFFF) - var mid = UInt32((quot >> 32) & 0xFFFFFFFF) - var high = UInt32((quot >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, scale_of_quot, is_negative) + return Decimal.from_uint128(quot, scale_of_quot, is_negative) # Otherwise, we need to truncate the first 29 or 28 digits else: @@ -1084,11 +1023,9 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: ) scale_of_truncated_quot = Decimal.MAX_SCALE - var low = UInt32(truncated_quot & 0xFFFFFFFF) - var mid = UInt32((truncated_quot >> 32) & 0xFFFFFFFF) - var high = UInt32((truncated_quot >> 64) & 0xFFFFFFFF) - - return Decimal(low, mid, high, scale_of_truncated_quot, is_negative) + return Decimal.from_uint128( + truncated_quot, scale_of_truncated_quot, is_negative + ) # SUB-CASE: Use UInt256 to store the quotient # Also the FALLBACK approach for the remaining cases diff --git a/src/decimojo/constants.mojo b/src/decimojo/constants.mojo new file mode 100644 index 0000000..738a98f --- /dev/null +++ b/src/decimojo/constants.mojo @@ -0,0 +1,609 @@ +# ===----------------------------------------------------------------------=== # +# DeciMojo: A fixed-point decimal arithmetic library in Mojo +# https://github.com/forfudan/decimojo +# +# Copyright 2025 Yuhao Zhu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +# +# Useful constants for Decimal type +# +# ===----------------------------------------------------------------------=== # + +"""Useful constants for Decimal type.""" + +# ===----------------------------------------------------------------------=== # +# +# Integer constants +# The prefix "M" stands for a decimal (money) value. +# This is a convention in C. +# +# ===----------------------------------------------------------------------=== # + + +@always_inline +fn M0() -> Decimal: + """Returns 0 as a Decimal.""" + return Decimal(0x0, 0x0, 0x0, 0x0) + + +@always_inline +fn M1() -> Decimal: + """Returns 1 as a Decimal.""" + return Decimal(0x1, 0x0, 0x0, 0x0) + + +@always_inline +fn M2() -> Decimal: + """Returns 2 as a Decimal.""" + return Decimal(0x2, 0x0, 0x0, 0x0) + + +@always_inline +fn M3() -> Decimal: + """Returns 3 as a Decimal.""" + return Decimal(0x3, 0x0, 0x0, 0x0) + + +@always_inline +fn M4() -> Decimal: + """Returns 4 as a Decimal.""" + return Decimal(0x4, 0x0, 0x0, 0x0) + + +@always_inline +fn M5() -> Decimal: + """Returns 5 as a Decimal.""" + return Decimal(0x5, 0x0, 0x0, 0x0) + + +@always_inline +fn M6() -> Decimal: + """Returns 6 as a Decimal.""" + return Decimal(0x6, 0x0, 0x0, 0x0) + + +@always_inline +fn M7() -> Decimal: + """Returns 7 as a Decimal.""" + return Decimal(0x7, 0x0, 0x0, 0x0) + + +@always_inline +fn M8() -> Decimal: + """Returns 8 as a Decimal.""" + return Decimal(0x8, 0x0, 0x0, 0x0) + + +@always_inline +fn M9() -> Decimal: + """Returns 9 as a Decimal.""" + return Decimal(0x9, 0x0, 0x0, 0x0) + + +@always_inline +fn M10() -> Decimal: + """Returns 10 as a Decimal.""" + return Decimal(0xA, 0x0, 0x0, 0x0) + + +# ===----------------------------------------------------------------------=== # +# +# Inverse constants +# +# ===----------------------------------------------------------------------=== # + + +@always_inline +fn INV2() -> Decimal: + """Returns 1/2 = 0.5.""" + return Decimal(0x5, 0x0, 0x0, 0x10000) + + +@always_inline +fn INV10() -> Decimal: + """Returns 1/10 = 0.1.""" + return Decimal(0x1, 0x0, 0x0, 0x10000) + + +@always_inline +fn INV0D1() -> Decimal: + """Returns 1/0.1 = 10.""" + return Decimal(0xA, 0x0, 0x0, 0x0) + + +@always_inline +fn INV0D2() -> Decimal: + """Returns 1/0.2 = 5.""" + return Decimal(0x5, 0x0, 0x0, 0x0) + + +@always_inline +fn INV0D3() -> Decimal: + """Returns 1/0.3 = 3.33333333333333333333333333333333...""" + return Decimal(0x35555555, 0xCF2607EE, 0x6BB4AFE4, 0x1C0000) + + +@always_inline +fn INV0D4() -> Decimal: + """Returns 1/0.4 = 2.5.""" + return Decimal(0x19, 0x0, 0x0, 0x10000) + + +@always_inline +fn INV0D5() -> Decimal: + """Returns 1/0.5 = 2.""" + return Decimal(0x2, 0x0, 0x0, 0x0) + + +@always_inline +fn INV0D6() -> Decimal: + """Returns 1/0.6 = 1.66666666666666666666666666666667...""" + return Decimal(0x1AAAAAAB, 0x679303F7, 0x35DA57F2, 0x1C0000) + + +@always_inline +fn INV0D7() -> Decimal: + """Returns 1/0.7 = 1.42857142857142857142857142857143...""" + return Decimal(0xCDB6DB6E, 0x3434DED3, 0x2E28DDAB, 0x1C0000) + + +@always_inline +fn INV0D8() -> Decimal: + """Returns 1/0.8 = 1.25.""" + return Decimal(0x7D, 0x0, 0x0, 0x20000) + + +@always_inline +fn INV0D9() -> Decimal: + """Returns 1/0.9 = 1.11111111111111111111111111111111...""" + return Decimal(0x671C71C7, 0x450CAD4F, 0x23E6E54C, 0x1C0000) + + +@always_inline +fn INV1() -> Decimal: + """Returns 1/1 = 1.""" + return Decimal(0x1, 0x0, 0x0, 0x0) + + +@always_inline +fn INV1D1() -> Decimal: + """Returns 1/1.1 = 0.90909090909090909090909090909091...""" + return Decimal(0x9A2E8BA3, 0x4FC48DCC, 0x1D5FD2E1, 0x1C0000) + + +@always_inline +fn INV1D2() -> Decimal: + """Returns 1/1.2 = 0.83333333333333333333333333333333...""" + return Decimal(0x8D555555, 0x33C981FB, 0x1AED2BF9, 0x1C0000) + + +@always_inline +fn INV1D3() -> Decimal: + """Returns 1/1.3 = 0.76923076923076923076923076923077...""" + return Decimal(0xC4EC4EC, 0x9243DA72, 0x18DAED83, 0x1C0000) + + +@always_inline +fn INV1D4() -> Decimal: + """Returns 1/1.4 = 0.71428571428571428571428571428571...""" + return Decimal(0xE6DB6DB7, 0x9A1A6F69, 0x17146ED5, 0x1C0000) + + +@always_inline +fn INV1D5() -> Decimal: + """Returns 1/1.5 = 0.66666666666666666666666666666667...""" + return Decimal(0xAAAAAAB, 0x296E0196, 0x158A8994, 0x1C0000) + + +@always_inline +fn INV1D6() -> Decimal: + """Returns 1/1.6 = 0.625.""" + return Decimal(0x271, 0x0, 0x0, 0x30000) + + +@always_inline +fn INV1D7() -> Decimal: + """Returns 1/1.7 = 0.58823529411764705882352941176471...""" + return Decimal(0x45A5A5A6, 0xE8520166, 0x1301C4AF, 0x1C0000) + + +@always_inline +fn INV1D8() -> Decimal: + """Returns 1/1.8 = 0.55555555555555555555555555555556...""" + return Decimal(0xB38E38E4, 0x228656A7, 0x11F372A6, 0x1C0000) + + +@always_inline +fn INV1D9() -> Decimal: + """Returns 1/1.9 = 0.52631578947368421052631578947368...""" + return Decimal(0xAA1AF287, 0x2E2E6D0A, 0x11019509, 0x1C0000) + + +# ===----------------------------------------------------------------------=== # +# +# N / (N+1) constants +# +# ===----------------------------------------------------------------------=== # + + +@always_inline +fn N_DIVIDE_NEXT(n: Int) raises -> Decimal: + """ + Returns the pre-calculated value of n/(n+1) for n between 1 and 20. + + Args: + n: An integer between 1 and 20, inclusive. + + Returns: + A Decimal representing the value of n/(n+1). + + Raises: + Error: If n is outside the range [1, 20]. + """ + if n == 1: + # 1/2 = 0.5 + return Decimal(0x5, 0x0, 0x0, 0x10000) + elif n == 2: + # 2/3 = 0.66666666666666666666666666666667... + return Decimal(0xAAAAAAB, 0x296E0196, 0x158A8994, 0x1C0000) + elif n == 3: + # 3/4 = 0.75 + return Decimal(0x4B, 0x0, 0x0, 0x20000) + elif n == 4: + # 4/5 = 0.8 + return Decimal(0x8, 0x0, 0x0, 0x10000) + elif n == 5: + # 5/6 = 0.83333333333333333333333333333333... + return Decimal(0x8D555555, 0x33C981FB, 0x1AED2BF9, 0x1C0000) + elif n == 6: + # 6/7 = 0.85714285714285714285714285714286... + return Decimal(0x7B6DB6DB, 0xEC1FB8E5, 0x1BB21E99, 0x1C0000) + elif n == 7: + # 7/8 = 0.875 + return Decimal(0x36B, 0x0, 0x0, 0x30000) + elif n == 8: + # 8/9 = 0.88888888888888888888888888888889... + return Decimal(0xB8E38E39, 0x373D5772, 0x1CB8B770, 0x1C0000) + elif n == 9: + # 9/10 = 0.9 + return Decimal(0x9, 0x0, 0x0, 0x10000) + elif n == 10: + # 10/11 = 0.90909090909090909090909090909091... + return Decimal(0x9A2E8BA3, 0x4FC48DCC, 0x1D5FD2E1, 0x1C0000) + elif n == 11: + # 11/12 = 0.91666666666666666666666666666667... + return Decimal(0x4EAAAAAB, 0xB8F7422E, 0x1D9E7D2B, 0x1C0000) + elif n == 12: + # 12/13 = 0.92307692307692307692307692307692... + return Decimal(0xEC4EC4F, 0xAF849FBC, 0x1DD3836A, 0x1C0000) + elif n == 13: + # 13/14 = 0.92857142857142857142857142857143... + return Decimal(0x45B6DB6E, 0x15225DA3, 0x1E00F67C, 0x1C0000) + elif n == 14: + # 14/15 = 0.93333333333333333333333333333333... + return Decimal(0x75555555, 0xD39A0238, 0x1E285A35, 0x1C0000) + elif n == 15: + # 15/16 = 0.9375 + return Decimal(0x249F, 0x0, 0x0, 0x40000) + elif n == 16: + # 16/17 = 0.94117647058823529411764705882353... + return Decimal(0x3C3C3C3C, 0xD50023D, 0x1E693AB3, 0x1C0000) + elif n == 17: + # 17/18 = 0.94444444444444444444444444444444... + return Decimal(0xE471C71C, 0x3AB12CE9, 0x1E8442E7, 0x1C0000) + elif n == 18: + # 18/19 = 0.94736842105263157894736842105263... + return Decimal(0xCBCA1AF3, 0x1FED2AAC, 0x1E9C72AA, 0x1C0000) + elif n == 19: + # 19/20 = 0.95 + return Decimal(0x5F, 0x0, 0x0, 0x20000) + elif n == 20: + # 20/21 = 0.95238095238095238095238095238095... + return Decimal(0x33CF3CF4, 0xCD78948D, 0x1EC5E91C, 0x1C0000) + else: + raise Error("N_DIVIDE_NEXT: n must be between 1 and 20, inclusive") + + +# ===----------------------------------------------------------------------=== # +# +# PI constants +# +# ===----------------------------------------------------------------------=== # + + +@always_inline +fn PI() -> Decimal: + """Returns the value of pi (π) as a Decimal.""" + return Decimal(0x41B65F29, 0xB143885, 0x6582A536, 0x1C0000) + + +# ===----------------------------------------------------------------------=== # +# +# EXP constants +# +# ===----------------------------------------------------------------------=== # + + +@always_inline +fn E() -> Decimal: + """ + Returns the value of Euler's number (e) as a Decimal. + + Returns: + A Decimal representation of Euler's number with maximum precision. + """ + return Decimal(0x857AED5A, 0xEBECDE35, 0x57D519AB, 0x1C0000) + + +@always_inline +fn E2() -> Decimal: + """Returns the value of e^2 as a Decimal.""" + return Decimal(0xE4DFDCAE, 0x89F7E295, 0xEEC0D6E9, 0x1C0000) + + +@always_inline +fn E3() -> Decimal: + """Returns the value of e^3 as a Decimal.""" + return Decimal(0x236454F7, 0x62055A80, 0x40E65DE2, 0x1B0000) + + +@always_inline +fn E4() -> Decimal: + """Returns the value of e^4 as a Decimal.""" + return Decimal(0x7121EFD3, 0xFB318FB5, 0xB06A87FB, 0x1B0000) + + +@always_inline +fn E5() -> Decimal: + """Returns the value of e^5 as a Decimal.""" + return Decimal(0xD99BD974, 0x9F4BE5C7, 0x2FF472E3, 0x1A0000) + + +@always_inline +fn E6() -> Decimal: + """Returns the value of e^6 as a Decimal.""" + return Decimal(0xADB57A66, 0xBD7A423F, 0x825AD8FF, 0x1A0000) + + +@always_inline +fn E7() -> Decimal: + """Returns the value of e^7 as a Decimal.""" + return Decimal(0x22313FCF, 0x64D5D12F, 0x236F230A, 0x190000) + + +@always_inline +fn E8() -> Decimal: + """Returns the value of e^8 as a Decimal.""" + return Decimal(0x1E892E63, 0xD1BF8B5C, 0x6051E812, 0x190000) + + +@always_inline +fn E9() -> Decimal: + """Returns the value of e^9 as a Decimal.""" + return Decimal(0x34FAB691, 0xE7CD8DEA, 0x1A2EB6C3, 0x180000) + + +@always_inline +fn E10() -> Decimal: + """Returns the value of e^10 as a Decimal.""" + return Decimal(0xBA7F4F65, 0x58692B62, 0x472BDD8F, 0x180000) + + +@always_inline +fn E11() -> Decimal: + """Returns the value of e^11 as a Decimal.""" + return Decimal(0x8C2C6D20, 0x2A86F9E7, 0xC176BAAE, 0x180000) + + +@always_inline +fn E12() -> Decimal: + """Returns the value of e^12 as a Decimal.""" + return Decimal(0xE924992A, 0x31CDC314, 0x3496C2C4, 0x170000) + + +@always_inline +fn E13() -> Decimal: + """Returns the value of e^13 as a Decimal.""" + return Decimal(0x220130DB, 0xC386029A, 0x8EF393FB, 0x170000) + + +@always_inline +fn E14() -> Decimal: + """Returns the value of e^14 as a Decimal.""" + return Decimal(0x3A24795C, 0xC412DF01, 0x26DBB5A0, 0x160000) + + +@always_inline +fn E15() -> Decimal: + """Returns the value of e^15 as a Decimal.""" + return Decimal(0x6C1248BD, 0x90456557, 0x69A0AD8C, 0x160000) + + +@always_inline +fn E16() -> Decimal: + """Returns the value of e^16 as a Decimal.""" + return Decimal(0xB46A97D, 0x90655BBD, 0x1CB66B18, 0x150000) + + +@always_inline +fn E32() -> Decimal: + """Returns the value of e^32 as a Decimal.""" + return Decimal(0x18420EB, 0xCC2501E6, 0xFF24A138, 0xF0000) + + +@always_inline +fn E0D5() -> Decimal: + """Returns the value of e^0.5 = e^(1/2) as a Decimal.""" + return Decimal(0x8E99DD66, 0xC210E35C, 0x3545E717, 0x1C0000) + + +@always_inline +fn E0D25() -> Decimal: + """Returns the value of e^0.25 = e^(1/4) as a Decimal.""" + return Decimal(0xB43646F1, 0x2654858A, 0x297D3595, 0x1C0000) + + +# ===----------------------------------------------------------------------=== # +# +# LN constants +# +# ===----------------------------------------------------------------------=== # + +# The repr of the magic numbers can be obtained by the following code: +# +# ```mojo +# fn print_repr_from_words(value: String, ln_value: String) raises: +# """ +# Prints the hex representation of a logarithm value. +# Args: +# value: The original value (for display purposes). +# ln_value: The natural logarithm as a String. +# """ +# var log_decimal = Decimal(ln_value) +# print("ln(" + value + "): " + log_decimal.repr_from_words()) +# ``` + + +# Constants for integers + + +@always_inline +fn LN1() -> Decimal: + """Returns ln(1) = 0.""" + return Decimal(0x0, 0x0, 0x0, 0x0) + + +@always_inline +fn LN2() -> Decimal: + """Returns ln(2) = 0.69314718055994530941723212145818...""" + return Decimal(0xAA7A65BF, 0x81F52F01, 0x1665943F, 0x1C0000) + + +@always_inline +fn LN10() -> Decimal: + """Returns ln(10) = 2.30258509299404568401799145468436...""" + return Decimal(0x9FA69733, 0x1414B220, 0x4A668998, 0x1C0000) + + +# Constants for values less than 1 +@always_inline +fn LN0D1() -> Decimal: + """Returns ln(0.1) = -2.30258509299404568401799145468436...""" + return Decimal(0x9FA69733, 0x1414B220, 0x4A668998, 0x801C0000) + + +@always_inline +fn LN0D2() -> Decimal: + """Returns ln(0.2) = -1.60943791243410037460075933322619...""" + return Decimal(0xF52C3174, 0x921F831E, 0x3400F558, 0x801C0000) + + +@always_inline +fn LN0D3() -> Decimal: + """Returns ln(0.3) = -1.20397280432593599262274621776184...""" + return Decimal(0x2B8E6822, 0x8258467, 0x26E70795, 0x801C0000) + + +@always_inline +fn LN0D4() -> Decimal: + """Returns ln(0.4) = -0.91629073187415506518352721176801...""" + return Decimal(0x4AB1CBB6, 0x102A541D, 0x1D9B6119, 0x801C0000) + + +@always_inline +fn LN0D5() -> Decimal: + """Returns ln(0.5) = -0.69314718055994530941723212145818...""" + return Decimal(0xAA7A65BF, 0x81F52F01, 0x1665943F, 0x801C0000) + + +@always_inline +fn LN0D6() -> Decimal: + """Returns ln(0.6) = -0.51082562376599068320551409630366...""" + return Decimal(0x81140263, 0x86305565, 0x10817355, 0x801C0000) + + +@always_inline +fn LN0D7() -> Decimal: + """Returns ln(0.7) = -0.35667494393873237891263871124118...""" + return Decimal(0x348BC5A8, 0x8B755D08, 0xB865892, 0x801C0000) + + +@always_inline +fn LN0D8() -> Decimal: + """Returns ln(0.8) = -0.22314355131420975576629509030983...""" + return Decimal(0xA03765F7, 0x8E35251B, 0x735CCD9, 0x801C0000) + + +@always_inline +fn LN0D9() -> Decimal: + """Returns ln(0.9) = -0.10536051565782630122750098083931...""" + return Decimal(0xB7763910, 0xFC3656AD, 0x3678591, 0x801C0000) + + +# Constants for values greater than 1 + + +@always_inline +fn LN1D1() -> Decimal: + """Returns ln(1.1) = 0.09531017980432486004395212328077...""" + return Decimal(0x7212FFD1, 0x7D9A10, 0x3146328, 0x1C0000) + + +@always_inline +fn LN1D2() -> Decimal: + """Returns ln(1.2) = 0.18232155679395462621171802515451...""" + return Decimal(0x2966635C, 0xFBC4D99C, 0x5E420E9, 0x1C0000) + + +@always_inline +fn LN1D3() -> Decimal: + """Returns ln(1.3) = 0.26236426446749105203549598688095...""" + return Decimal(0xE0BE71FD, 0xC254E078, 0x87A39F0, 0x1C0000) + + +@always_inline +fn LN1D4() -> Decimal: + """Returns ln(1.4) = 0.33647223662121293050459341021699...""" + return Decimal(0x75EEA016, 0xF67FD1F9, 0xADF3BAC, 0x1C0000) + + +@always_inline +fn LN1D5() -> Decimal: + """Returns ln(1.5) = 0.40546510810816438197801311546435...""" + return Decimal(0xC99DC953, 0x89F9FEB7, 0xD19EDC3, 0x1C0000) + + +@always_inline +fn LN1D6() -> Decimal: + """Returns ln(1.6) = 0.47000362924573555365093703114834...""" + return Decimal(0xA42FFC8, 0xF3C009E6, 0xF2FC765, 0x1C0000) + + +@always_inline +fn LN1D7() -> Decimal: + """Returns ln(1.7) = 0.53062825106217039623154316318876...""" + return Decimal(0x64BB9ED0, 0x4AB9978F, 0x11254107, 0x1C0000) + + +@always_inline +fn LN1D8() -> Decimal: + """Returns ln(1.8) = 0.58778666490211900818973114061886...""" + return Decimal(0xF3042CAE, 0x85BED853, 0x12FE0EAD, 0x1C0000) + + +@always_inline +fn LN1D9() -> Decimal: + """Returns ln(1.9) = 0.64185388617239477599103597720349...""" + return Decimal(0x12F992DC, 0xE7374425, 0x14BD4A78, 0x1C0000) diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index b2c0afd..f1a8f59 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -57,6 +57,7 @@ import testing import decimojo.arithmetics import decimojo.comparison +import decimojo.constants import decimojo.exponential import decimojo.rounding from decimojo.rounding_mode import RoundingMode @@ -207,110 +208,13 @@ struct Decimal( @staticmethod fn PI() -> Decimal: - """ - Returns the value of pi (π) as a Decimal. - - Returns: - A Decimal representation of pi with maximum precision. - """ - return Decimal(0x41B65F29, 0xB143885, 0x6582A536, 0x1C0000) + """Returns the value of pi (π) as a Decimal.""" + return decimojo.constants.PI() @staticmethod fn E() -> Decimal: - """ - Returns the value of Euler's number (e) as a Decimal. - - Returns: - A Decimal representation of Euler's number with maximum precision. - """ - return Decimal(0x857AED5A, 0xEBECDE35, 0x57D519AB, 0x1C0000) - - @staticmethod - fn E2() -> Decimal: - return Decimal(0xE4DFDCAE, 0x89F7E295, 0xEEC0D6E9, 0x1C0000) - - @staticmethod - fn E3() -> Decimal: - return Decimal(0x236454F7, 0x62055A80, 0x40E65DE2, 0x1B0000) - - @staticmethod - fn E4() -> Decimal: - return Decimal(0x7121EFD3, 0xFB318FB5, 0xB06A87FB, 0x1B0000) - - @staticmethod - fn E5() -> Decimal: - return Decimal(0xD99BD974, 0x9F4BE5C7, 0x2FF472E3, 0x1A0000) - - @staticmethod - fn E6() -> Decimal: - return Decimal(0xADB57A66, 0xBD7A423F, 0x825AD8FF, 0x1A0000) - - @staticmethod - fn E7() -> Decimal: - return Decimal(0x22313FCF, 0x64D5D12F, 0x236F230A, 0x190000) - - @staticmethod - fn E8() -> Decimal: - return Decimal(0x1E892E63, 0xD1BF8B5C, 0x6051E812, 0x190000) - - @staticmethod - fn E9() -> Decimal: - return Decimal(0x34FAB691, 0xE7CD8DEA, 0x1A2EB6C3, 0x180000) - - @staticmethod - fn E10() -> Decimal: - return Decimal(0xBA7F4F65, 0x58692B62, 0x472BDD8F, 0x180000) - - @staticmethod - fn E11() -> Decimal: - return Decimal(0x8C2C6D20, 0x2A86F9E7, 0xC176BAAE, 0x180000) - - @staticmethod - fn E12() -> Decimal: - return Decimal(0xE924992A, 0x31CDC314, 0x3496C2C4, 0x170000) - - @staticmethod - fn E13() -> Decimal: - return Decimal(0x220130DB, 0xC386029A, 0x8EF393FB, 0x170000) - - @staticmethod - fn E14() -> Decimal: - return Decimal(0x3A24795C, 0xC412DF01, 0x26DBB5A0, 0x160000) - - @staticmethod - fn E15() -> Decimal: - """Returns the value of e^15 as a Decimal.""" - return Decimal(0x6C1248BD, 0x90456557, 0x69A0AD8C, 0x160000) - - @staticmethod - fn E16() -> Decimal: - """Returns the value of e^16 as a Decimal.""" - return Decimal(0xB46A97D, 0x90655BBD, 0x1CB66B18, 0x150000) - - @staticmethod - fn E32() -> Decimal: - """Returns the value of e^32 as a Decimal.""" - return Decimal(0x18420EB, 0xCC2501E6, 0xFF24A138, 0xF0000) - - @staticmethod - fn E05() -> Decimal: - """Returns the value of e^0.5 = e^(1/2) as a Decimal.""" - return Decimal(0x8E99DD66, 0xC210E35C, 0x3545E717, 0x1C0000) - - @staticmethod - fn E025() -> Decimal: - """Returns the value of e^0.25 = e^(1/4) as a Decimal.""" - return Decimal(0xB43646F1, 0x2654858A, 0x297D3595, 0x1C0000) - - @staticmethod - fn LN10() -> Decimal: - """ - Returns the natural logarithm of 10 as a Decimal. - - Returns: - A Decimal representation of ln(10) with maximum precision. - """ - return Decimal(0x9FA69733, 0x1414B220, 0x4A668998, 0x1C0000) + """Returns the value of Euler's number (e) as a Decimal.""" + return decimojo.constants.E() # ===------------------------------------------------------------------=== # # Constructors and life time dunder methods @@ -1039,7 +943,7 @@ struct Decimal( Decimal.from_words(low, mid, high, flags). """ return ( - "Decimal.from_words(" + "Decimal(" + hex(self.low) + ", " + hex(self.mid) @@ -1369,6 +1273,12 @@ struct Decimal( except e: raise Error("Error in `Decimal.exp()`: ", e) + fn ln(self) raises -> Self: + try: + return decimojo.exponential.ln(self) + except e: + raise Error("Error in `Decimal.ln()`: ", e) + fn round( self, ndigits: Int = 0, diff --git a/src/decimojo/exponential.mojo b/src/decimojo/exponential.mojo index 25cfe5c..4d97c2b 100644 --- a/src/decimojo/exponential.mojo +++ b/src/decimojo/exponential.mojo @@ -32,6 +32,7 @@ import math as builtin_math import testing +import decimojo.constants import decimojo.special import decimojo.utility @@ -294,91 +295,91 @@ fn exp(x: Decimal) raises -> Decimal: var x_int = Int(x) if x.is_one(): - return Decimal.E() + return decimojo.constants.E() elif x_int < 1: - var d05 = Decimal(5, 0, 0, scale=1, sign=False) # 0.5 - var d025 = Decimal(25, 0, 0, scale=2, sign=False) # 0.25 + var M0D5 = Decimal(5, 0, 0, 1 << 16) # 0.5 + var M0D25 = Decimal(25, 0, 0, 2 << 16) # 0.25 - if x < d025: # 0 < x < 0.25 + if x < M0D25: # 0 < x < 0.25 return exp_series(x) - elif x < d05: # 0.25 <= x < 0.5 - exp_chunk = Decimal.E025() - remainder = x - d025 + elif x < M0D5: # 0.25 <= x < 0.5 + exp_chunk = decimojo.constants.E0D25() + remainder = x - M0D25 else: # 0.5 <= x < 1 - exp_chunk = Decimal.E05() - remainder = x - d05 + exp_chunk = decimojo.constants.E0D5() + remainder = x - M0D5 elif x_int == 1: # 1 <= x < 2, chunk = 1 - exp_chunk = Decimal.E() + exp_chunk = decimojo.constants.E() remainder = x - x_int elif x_int == 2: # 2 <= x < 3, chunk = 2 - exp_chunk = Decimal.E2() + exp_chunk = decimojo.constants.E2() remainder = x - x_int elif x_int == 3: # 3 <= x < 4, chunk = 3 - exp_chunk = Decimal.E3() + exp_chunk = decimojo.constants.E3() remainder = x - x_int elif x_int == 4: # 4 <= x < 5, chunk = 4 - exp_chunk = Decimal.E4() + exp_chunk = decimojo.constants.E4() remainder = x - x_int elif x_int == 5: # 5 <= x < 6, chunk = 5 - exp_chunk = Decimal.E5() + exp_chunk = decimojo.constants.E5() remainder = x - x_int elif x_int == 6: # 6 <= x < 7, chunk = 6 - exp_chunk = Decimal.E6() + exp_chunk = decimojo.constants.E6() remainder = x - x_int elif x_int == 7: # 7 <= x < 8, chunk = 7 - exp_chunk = Decimal.E7() + exp_chunk = decimojo.constants.E7() remainder = x - x_int elif x_int == 8: # 8 <= x < 9, chunk = 8 - exp_chunk = Decimal.E8() + exp_chunk = decimojo.constants.E8() remainder = x - x_int elif x_int == 9: # 9 <= x < 10, chunk = 9 - exp_chunk = Decimal.E9() + exp_chunk = decimojo.constants.E9() remainder = x - x_int elif x_int == 10: # 10 <= x < 11, chunk = 10 - exp_chunk = Decimal.E10() + exp_chunk = decimojo.constants.E10() remainder = x - x_int elif x_int == 11: # 11 <= x < 12, chunk = 11 - exp_chunk = Decimal.E11() + exp_chunk = decimojo.constants.E11() remainder = x - x_int elif x_int == 12: # 12 <= x < 13, chunk = 12 - exp_chunk = Decimal.E12() + exp_chunk = decimojo.constants.E12() remainder = x - x_int elif x_int == 13: # 13 <= x < 14, chunk = 13 - exp_chunk = Decimal.E13() + exp_chunk = decimojo.constants.E13() remainder = x - x_int elif x_int == 14: # 14 <= x < 15, chunk = 14 - exp_chunk = Decimal.E14() + exp_chunk = decimojo.constants.E14() remainder = x - x_int elif x_int == 15: # 15 <= x < 16, chunk = 15 - exp_chunk = Decimal.E15() + exp_chunk = decimojo.constants.E15() remainder = x - x_int elif x_int < 32: # 16 <= x < 32, chunk = 16 num_chunks = x_int >> 4 - exp_chunk = Decimal.E16() + exp_chunk = decimojo.constants.E16() remainder = x - (num_chunks << 4) else: # chunk = 32 num_chunks = x_int >> 5 - exp_chunk = Decimal.E32() + exp_chunk = decimojo.constants.E32() remainder = x - (num_chunks << 5) # Calculate e^(chunk * num_chunks) = (e^chunk)^num_chunks @@ -417,7 +418,7 @@ fn exp_series(x: Decimal) raises -> Decimal: return Decimal.ONE() # For x with very small magnitude, just use 1+x approximation - if abs(x) == Decimal("1e-28"): + if abs(x) == Decimal(1, 0, 0, 28 << 16): return Decimal.ONE() + x # Initialize result and term @@ -441,3 +442,268 @@ fn exp_series(x: Decimal) raises -> Decimal: result = result + term return result + + +fn ln(x: Decimal) raises -> Decimal: + """ + Calculates the natural logarithm (ln) of a Decimal value. + + Args: + x: The Decimal value to compute the natural logarithm of. + + Returns: + A Decimal approximation of ln(x). + + Raises: + Error: If x is less than or equal to zero. + + Notes: + This implementation uses range reduction to improve accuracy and performance. + """ + + # print("DEBUG: ln(x) called with x =", x) + + # Handle special cases + if x.is_negative() or x.is_zero(): + raise Error( + "Error in ln(): Cannot compute logarithm of a non-positive number" + ) + + if x.is_one(): + return Decimal.ZERO() + + # Special cases for common values + if x == decimojo.constants.E(): + return Decimal.ONE() + + # For values close to 1, use series expansion directly + if Decimal(95, 0, 0, 2 << 16) <= x <= Decimal(105, 0, 0, 2 << 16): + return ln_series(x - Decimal.ONE()) + + # For all other values, use range reduction + # ln(x) = ln(m * 2^p * 10^q) = ln(m) + p*ln(2) + q*ln(10), where 1 <= m < 2 + + var m: Decimal = x + var p: Int = 0 + var q: Int = 0 + + # Step 1: handle powers of 10 for large values + if x >= decimojo.constants.M10(): + # Repeatedly divide by 10 until m < 10 + while m >= decimojo.constants.M10(): + m = m / decimojo.constants.M10() + q += 1 + elif x < Decimal(1, 0, 0, 1 << 16): + # Repeatedly multiply by 10 until m >= 0.1 + while m < Decimal(1, 0, 0, 1 << 16): + m = m * decimojo.constants.M10() + q -= 1 + + # Now 0.1 <= m < 10 + # Step 2: normalize to [0.5, 2) using powers of 2 + if m >= decimojo.constants.M2(): + # Repeatedly divide by 2 until m < 2 + while m >= decimojo.constants.M2(): + m = m / decimojo.constants.M2() + p += 1 + elif m < Decimal(5, 0, 0, 1 << 16): + # Repeatedly multiply by 2 until m >= 0.5 + while m < Decimal(5, 0, 0, 1 << 16): + m = m * decimojo.constants.M2() + p -= 1 + + # Now 0.5 <= m < 2 + var ln_m: Decimal + + # Use precomputed values and series expansion for accuracy and performance + if m < Decimal.ONE(): + # For 0.5 <= m < 1 + if m >= Decimal(9, 0, 0, 1 << 16): + ln_m = ( + ln_series( + (m - Decimal(9, 0, 0, 1 << 16)) + * decimojo.constants.INV0D9() + ) + + decimojo.constants.LN0D9() + ) + elif m >= Decimal(8, 0, 0, 1 << 16): + ln_m = ( + ln_series( + (m - Decimal(8, 0, 0, 1 << 16)) + * decimojo.constants.INV0D8() + ) + + decimojo.constants.LN0D8() + ) + elif m >= Decimal(7, 0, 0, 1 << 16): + ln_m = ( + ln_series( + (m - Decimal(7, 0, 0, 1 << 16)) + * decimojo.constants.INV0D7() + ) + + decimojo.constants.LN0D7() + ) + elif m >= Decimal(6, 0, 0, 1 << 16): + ln_m = ( + ln_series( + (m - Decimal(6, 0, 0, 1 << 16)) + * decimojo.constants.INV0D6() + ) + + decimojo.constants.LN0D6() + ) + else: # 0.5 <= m < 0.6 + ln_m = ( + ln_series( + (m - Decimal(5, 0, 0, 1 << 16)) + * decimojo.constants.INV0D5() + ) + + decimojo.constants.LN0D5() + ) + + else: + # For 1 < m < 2 + if m < Decimal(11, 0, 0, 1 << 16): # 1 < m < 1.1 + ln_m = ln_series(m - Decimal.ONE()) + elif m < Decimal(12, 0, 0, 1 << 16): # 1.1 <= m < 1.2 + ln_m = ( + ln_series( + (m - Decimal(11, 0, 0, 1 << 16)) + * decimojo.constants.INV1D1() + ) + + decimojo.constants.LN1D1() + ) + elif m < Decimal(13, 0, 0, 1 << 16): # 1.2 <= m < 1.3 + ln_m = ( + ln_series( + (m - Decimal(12, 0, 0, 1 << 16)) + * decimojo.constants.INV1D2() + ) + + decimojo.constants.LN1D2() + ) + elif m < Decimal(14, 0, 0, 1 << 16): # 1.3 <= m < 1.4 + ln_m = ( + ln_series( + (m - Decimal(13, 0, 0, 1 << 16)) + * decimojo.constants.INV1D3() + ) + + decimojo.constants.LN1D3() + ) + elif m < Decimal(15, 0, 0, 1 << 16): # 1.4 <= m < 1.5 + ln_m = ( + ln_series( + (m - Decimal(14, 0, 0, 1 << 16)) + * decimojo.constants.INV1D4() + ) + + decimojo.constants.LN1D4() + ) + elif m < Decimal(16, 0, 0, 1 << 16): # 1.5 <= m < 1.6 + ln_m = ( + ln_series( + (m - Decimal(15, 0, 0, 1 << 16)) + * decimojo.constants.INV1D5() + ) + + decimojo.constants.LN1D5() + ) + elif m < Decimal(17, 0, 0, 1 << 16): # 1.6 <= m < 1.7 + ln_m = ( + ln_series( + (m - Decimal(16, 0, 0, 1 << 16)) + * decimojo.constants.INV1D6() + ) + + decimojo.constants.LN1D6() + ) + elif m < Decimal(18, 0, 0, 1 << 16): # 1.7 <= m < 1.8 + ln_m = ( + ln_series( + (m - Decimal(17, 0, 0, 1 << 16)) + * decimojo.constants.INV1D7() + ) + + decimojo.constants.LN1D7() + ) + elif m < Decimal(19, 0, 0, 1 << 16): # 1.8 <= m < 1.9 + ln_m = ( + ln_series( + (m - Decimal(18, 0, 0, 1 << 16)) + * decimojo.constants.INV1D8() + ) + + decimojo.constants.LN1D8() + ) + else: # 1.9 <= m < 2 + ln_m = ( + ln_series( + (m - Decimal(19, 0, 0, 1 << 16)) + * decimojo.constants.INV1D9() + ) + + decimojo.constants.LN1D9() + ) + + # Combine result: ln(x) = ln(m) + p*ln(2) + q*ln(10) + var result = ln_m + + # Add power of 2 contribution + if p != 0: + result = result + Decimal(p) * decimojo.constants.LN2() + + # Add power of 10 contribution + if q != 0: + result = result + Decimal(q) * decimojo.constants.LN10() + + return result + + +fn ln_series(z: Decimal) raises -> Decimal: + """ + Calculates ln(1+z) using Taylor series expansion at 1. + For best accuracy, |z| should be small (< 0.5). + + Args: + z: The value to compute ln(1+z) for. + + Returns: + A Decimal approximation of ln(1+z). + + Notes: + Uses the series: ln(1+z) = z - z²/2 + z³/3 - z⁴/4 + ... + This series converges fastest when |z| is small. + """ + + # print("DEBUG: ln_series(z) called with z =", z) + + var max_terms = 500 + + # For z=0, ln(1+z) = ln(1) = 0 + if z.is_zero(): + return Decimal.ZERO() + + # For z with very small magnitude, just use z approximation + if abs(z) == Decimal(1, 0, 0, 28 << 16): + return z + + # Initialize result and term + var result = Decimal.ZERO() + var term = z + var neg: Bool = False + + # Calculate terms iteratively + # term[i] = (-1)^(i+1) * z^i / i + + for i in range(1, max_terms + 1): + if neg: + result = result - term + else: + result = result + term + + neg = not neg # Alternate sign + + if i <= 20: + term = term * z * decimojo.constants.N_DIVIDE_NEXT(i) + else: + term = term * z * Decimal(i) / Decimal(i + 1) + + # Check for convergence + if term.is_zero(): + # print("DEBUG: i = ", i) + break + + # print("DEBUG: result =", result) + + return result diff --git a/tests/test_divide.mojo b/tests/test_divide.mojo index 3ae347b..347b355 100644 --- a/tests/test_divide.mojo +++ b/tests/test_divide.mojo @@ -417,8 +417,8 @@ fn test_large_numbers() raises: testing.assert_equal(a53, Decimal("7922816251426433759354395033.4")) # 54. Division where result approaches max - var large_num = Decimal.MAX() / Decimal("2") - var a54 = large_num * Decimal("2") + var large_num = Decimal.MAX() / Decimal("3") + var a54 = large_num * Decimal("3") testing.assert_true( a54 <= Decimal.MAX(), "Case 54: Division where result approaches max failed", diff --git a/tests/test_ln.mojo b/tests/test_ln.mojo new file mode 100644 index 0000000..a14515c --- /dev/null +++ b/tests/test_ln.mojo @@ -0,0 +1,272 @@ +""" +Comprehensive tests for the ln() function in the DeciMojo library. +Tests various cases including basic values, mathematical identities, +and edge cases to ensure proper calculation of the natural logarithm. +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode +from decimojo.exponential import ln + + +fn test_basic_ln_values() raises: + """Test basic natural logarithm values.""" + print("Testing basic natural logarithm values...") + + # Test case 1: ln(1) = 0 + var one = Decimal("1") + var result1 = ln(one) + testing.assert_equal( + String(result1), "0", "ln(1) should be 0, got " + String(result1) + ) + + # Test case 2: ln(e) = 1 + var e = Decimal("2.718281828459045235360287471") + var result_e = ln(e) + testing.assert_true( + String(result_e).startswith("1.00000000000000000000"), + "ln(e) should be approximately 1, got " + String(result_e), + ) + + # Test case 3: ln(10) + var ten = Decimal("10") + var result_ten = ln(ten) + testing.assert_true( + String(result_ten).startswith("2.30258509299404568401799145"), + "ln(10) should be approximately 2.30258509299404568401799145..., got " + + String(result_ten), + ) + + # Test case 4: ln(0.1) + var tenth = Decimal("0.1") + var result_tenth = ln(tenth) + testing.assert_true( + String(result_tenth).startswith("-2.302585092994045684017991454"), + "ln(0.1) should be approximately -2.302585092994045684017991454...," + " got " + + String(result_tenth), + ) + + print("✓ Basic natural logarithm values tests passed!") + + +fn test_fractional_ln_values() raises: + """Test natural logarithm values with fractional inputs.""" + print("Testing natural logarithm values with fractional inputs...") + + # Test case 5: ln(0.5) + var half = Decimal("0.5") + var result_half = ln(half) + testing.assert_true( + String(result_half).startswith("-0.693147180559945309417232121"), + "ln(0.5) should be approximately -0.693147180559945309417232121...," + " got " + + String(result_half), + ) + + # Test case 6: ln(2) + var two = Decimal("2") + var result_two = ln(two) + testing.assert_true( + String(result_two).startswith("0.693147180559945309417232121"), + "ln(2) should be approximately 0.693147180559945309417232121..., got " + + String(result_two), + ) + + # Test case 7: ln(5) + var five = Decimal("5") + var result_five = ln(five) + testing.assert_true( + String(result_five).startswith("1.609437912434100374600759333"), + "ln(5) should be approximately 1.609437912434100374600759333..., got " + + String(result_five), + ) + + print("✓ Fractional natural logarithm values tests passed!") + + +fn test_mathematical_identities() raises: + """Test mathematical identities related to the natural logarithm.""" + print("Testing mathematical identities for natural logarithm...") + + # Test case 8: ln(a * b) = ln(a) + ln(b) + var a = Decimal("2") + var b = Decimal("3") + var ln_a_times_b = ln(a * b) + var ln_a_plus_ln_b = ln(a) + ln(b) + testing.assert_true( + abs(ln_a_times_b - ln_a_plus_ln_b) < Decimal("0.0000000001"), + "ln(a * b) should equal ln(a) + ln(b) within tolerance", + ) + + # Test case 9: ln(a / b) = ln(a) - ln(b) + var ln_a_div_b = ln(a / b) + var ln_a_minus_ln_b = ln(a) - ln(b) + testing.assert_true( + abs(ln_a_div_b - ln_a_minus_ln_b) < Decimal("0.0000000001"), + "ln(a / b) should equal ln(a) - ln(b) within tolerance", + ) + + # Test case 10: ln(e^x) = x + var x = Decimal("5") + var ln_e_to_x = ln(dm.exponential.exp(x)) + testing.assert_true( + abs(ln_e_to_x - x) < Decimal("0.0000000001"), + "ln(e^x) should equal x within tolerance", + ) + + print("✓ Mathematical identities tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for natural logarithm function.""" + print("Testing edge cases for natural logarithm function...") + + # Test case 11: ln(0) should raise an exception + var zero = Decimal("0") + var exception_caught = False + try: + var _ln0 = ln(zero) + testing.assert_equal(True, False, "ln(0) should raise an exception") + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + # Test case 12: ln of a negative number should raise an exception + var neg_one = Decimal("-1") + exception_caught = False + try: + var _ln = ln(neg_one) + testing.assert_equal( + True, False, "ln of a negative number should raise an exception" + ) + except: + exception_caught = True + testing.assert_equal(exception_caught, True) + + # Test case 13: ln of a very small number + var very_small = Decimal("0.000000000000000000000000001") + var result_small = ln(very_small) + testing.assert_true( + String(result_small).startswith("-62.16979751083923346848576927"), + "ln of a very small number should be -62.16979751083923346848576927...," + " but got {}".format(result_small), + ) + + # Test case 14: ln of a very large number + var very_large = Decimal("10000000000000000000000000000") + var result_large = ln(very_large) + testing.assert_true( + String(result_large).startswith("64.4723"), + "ln of a very large number should be 64.4723..., but got {}".format( + result_large + ), + ) + + print("✓ Edge cases tests passed!") + + +fn test_precision() raises: + """Test precision of natural logarithm calculations.""" + print("Testing precision of natural logarithm calculations...") + + # Test case 15: ln(2) with high precision + var two = Decimal("2") + var result_two = ln(two) + testing.assert_true( + String(result_two).startswith("0.693147180559945309417232121"), + "ln(2) with high precision should be accurate", + ) + + # Test case 16: ln(10) with high precision + var ten = Decimal("10") + var result_ten = ln(ten) + testing.assert_true( + String(result_ten).startswith("2.30258509299404568401"), + "ln(10) with high precision should be 2.30258509299404568401..., but" + " got {}".format(result_ten), + ) + + print("✓ Precision tests passed!") + + +fn test_range_of_values() raises: + """Test natural logarithm function across a range of values.""" + print("Testing natural logarithm function across a range of values...") + + # Test case 17: ln(x) for x in range (3, 10) + testing.assert_true( + Decimal(3).ln() > Decimal(0), "ln(x) should be positive for x > 2" + ) + testing.assert_true( + Decimal(10).ln() > Decimal(2), + "ln(x) should be greater than x for x > 2", + ) + + # Test case 18: ln(x) for x in range (0.1, 1, 0.1) + + testing.assert_true( + Decimal("0.1").ln() < Decimal(0), "ln(x) should be negative for x < 1" + ) + testing.assert_true( + Decimal("0.9").ln() < Decimal(0), "ln(x) should be negative for x < 1" + ) + + print("✓ Range of values tests passed!") + + +fn test_special_cases() raises: + """Test special cases for natural logarithm function.""" + print("Testing special cases for natural logarithm function...") + + # Test case 19: ln(1) = 0 (revisited) + var one = Decimal("1") + var result_one = ln(one) + testing.assert_equal(String(result_one), "0", "ln(1) should be exactly 0") + + # Test case 20: ln(e) close to 1 + var e = Decimal("2.718281828459045235360287471") + var result_e = ln(e) + testing.assert_true( + abs(result_e - Decimal("1")) < Decimal("0.0000000001"), + "ln(e) should be very close to 1", + ) + + print("✓ Special cases tests passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Natural Logarithm Function Tests") + print("=========================================") + + run_test_with_error_handling(test_basic_ln_values, "Basic ln values test") + run_test_with_error_handling( + test_fractional_ln_values, "Fractional ln values test" + ) + run_test_with_error_handling( + test_mathematical_identities, "Mathematical identities test" + ) + run_test_with_error_handling(test_edge_cases, "Edge cases test") + run_test_with_error_handling(test_precision, "Precision test") + run_test_with_error_handling(test_range_of_values, "Range of values test") + run_test_with_error_handling(test_special_cases, "Special cases test") + + print("All natural logarithm function tests passed!")