diff --git a/README.md b/README.md index 1b06226..a2a6a52 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,39 @@ This project draws inspiration from several established decimal implementations ## Nomenclature -DeciMojo combines "Decimal" and "Mojo" - reflecting both its purpose (decimal arithmetic) and the programming language it's implemented in. The name highlights the project's focus on bringing precise decimal calculations to the Mojo ecosystem. +DeciMojo combines "Decimal" and "Mojo" - reflecting both its purpose (decimal arithmetic) and the programming language it's implemented in. The name emphasizes the project's focus on bringing precise decimal calculations to the Mojo ecosystem. -For brevity, you can also refer to it "decimo" (derived from the Latin root "decimus" meaning "tenth"). +For brevity, you can refer to it as "deci" (derived from the Latin root "decimus" meaning "tenth"). + +When you add `from decimojo import dm, Decimal` at the top of your script, this imports the `decimojo` module into your namespace with the shorter alias `dm` and directly imports the `Decimal` type. This is equivalent to: + +```mojo +import decimojo as dm +from decimojo import Decimal +``` + +## Advantages + +DeciMojo provides exceptional computational precision without sacrificing performance. It maintains accuracy throughout complex calculations where floating-point or other decimal implementations might introduce subtle errors. + +Consider the square root of `15.9999`. When comparing DeciMojo's implementation with Python's decimal module (both rounded to 16 decimal places): + +- DeciMojo calculates: `3.9999874999804687` +- Python's decimal returns: `3.9999874999804685` + +The mathematically correct value (to 50+ digits) is: +`3.9999874999804686889646053303778122644631365491812...` + +When rounded to 16 decimal places, the correct result is `3.9999874999804687`, confirming that DeciMojo produces the more accurate result in this case. + +```log +Function: sqrt() +Decimal value: 15.9999 +DeciMojo result: 3.9999874999804686889646053305 +Python's decimal result: 3.9999874999804685 +``` + +This precision advantage becomes increasingly important in financial, scientific, and engineering calculations where small rounding errors can compound into significant discrepancies. ## Status @@ -44,20 +74,20 @@ Rome wasn't built in a day. DeciMojo is currently under active development, posi ### Make it Fast ⏳ (IN PROGRESS & FUTURE WORK) -- Performance optimization on basic operations (+ - * /) in progress +- Performance optimization on basic operations (+ - * /) is mostly finished ([PR#16](https://github.com/forFudan/DeciMojo/pull/16), [PR#20](https://github.com/forFudan/DeciMojo/pull/20), [PR#21](https://github.com/forFudan/DeciMojo/pull/21)). - Regular benchmarking against Python's `decimal` module (see `bench/` folder) - Performance optimization on other functions are acknowleged but not currently prioritized ## Examples -The `Decimal` type can represent at most 29 significant digits and 28 digits after the decimal point. If the significant digits of a decimal value exceeds the maximum value (`2^96 - 1`), it either casts and error, or it is rounded until the siginificant digits are within the maximum value. For example, for the number `8.8888888888888888888888888888`, the significant digits (29 eights)exceeds the maximum value. It will then be rounded into `8.888888888888888888888888888` (28 eights). +The `Decimal` type can represent values with up to 29 significant digits and a maximum of 28 digits after the decimal point. When a value exceeds the maximum representable value (`2^96 - 1`), DeciMojo either raises an error or rounds the value to fit within these constraints. For example, the significant digits of `8.8888888888888888888888888888` (29 eights total with 28 after the decimal point) exceeds the maximum representable value (`2^96 - 1`) and is automatically rounded to `8.888888888888888888888888888` (28 eights total with 27 after the decimal point). -Here are 10 key examples highlighting the most important features of the `Decimal` type in its current state: +Here are 8 key examples highlighting the most important features of the `Decimal` type in its current state: ### 1. Fixed-Point Precision for Financial Calculations ```mojo -from decimojo.prelude import * +from decimojo import dm, Decimal # The classic floating-point problem print(0.1 + 0.2) # 0.30000000000000004 (not exactly 0.3) @@ -75,7 +105,7 @@ var tax = price * tax_rate # Exactly 1.449275 var total = price + tax # Exactly 21.439275 ``` -### 2. Basic Arithmetic with Proper Rounding +### 2. Basic Arithmetic with Proper Banker's Rounding ```mojo # Addition with different scales @@ -88,15 +118,15 @@ var c = Decimal("50") var d = Decimal("75.25") print(c - d) # -25.25 -# Multiplication preserving full precision -var e = Decimal("12.34") +# Multiplication with banker's rounding (round to even) +var e = Decimal("12.345") var f = Decimal("5.67") -print(e * f) # 69.9678 (all digits preserved) +print(round(e * f, 2)) # 69.96 (rounds to nearest even) -# Division with repeating decimals handled precisely -var g = Decimal("1") +# Division with banker's rounding +var g = Decimal("10") var h = Decimal("3") -print(g / h) # 0.3333333333333333333333333333 (to precision limit) +print(round(g / h, 2)) # 3.33 (rounded banker's style) ``` ### 3. Scale and Precision Management @@ -106,19 +136,17 @@ print(g / h) # 0.3333333333333333333333333333 (to precision limit) var d1 = Decimal("123.45") print(d1.scale()) # 2 -var d2 = Decimal("123.4500") -print(d2.scale()) # 4 - -# Arithmetic operations combine scales appropriately -var sum = Decimal("0.123") + Decimal("0.45") # Takes larger scale -print(sum) # 0.573 - -var product = Decimal("0.12") * Decimal("0.34") # Sums the scales -print(product) # 0.0408 +# Precision control with explicit rounding +var d2 = Decimal("123.456") +print(d2.round_to_scale(1)) # 123.5 (banker's rounding) # High precision is preserved (up to 28 decimal places) var precise = Decimal("0.1234567890123456789012345678") print(precise) # 0.1234567890123456789012345678 + +# Truncation to specific number of digits +var large_num = Decimal("123456.789") +print(truncate_to_digits(large_num, 4)) # 1235 (banker's rounded) ``` ### 4. Sign Handling and Absolute Value @@ -129,10 +157,6 @@ var pos = Decimal("123.45") var neg = -pos print(neg) # -123.45 -# Multiple negations -var back_to_pos = -(-pos) -print(back_to_pos) # 123.45 - # Absolute value var abs_val = abs(Decimal("-987.65")) print(abs_val) # 987.65 @@ -140,7 +164,10 @@ print(abs_val) # 987.65 # Sign checking print(Decimal("-123.45").is_negative()) # True print(Decimal("0").is_negative()) # False -print(Decimal("123.45").is_negative()) # False + +# Sign preservation in multiplication +print(Decimal("-5") * Decimal("3")) # -15 +print(Decimal("-5") * Decimal("-3")) # 15 ``` ### 5. Advanced Mathematical Operations @@ -148,46 +175,24 @@ print(Decimal("123.45").is_negative()) # False ```mojo from decimojo.mathematics import sqrt -# Integer powers -var squared = Decimal("3") ** 2 -print(squared) # 9 - -# Negative powers (reciprocals) -var recip = Decimal("2") ** (-1) -print(recip) # 0.5 - -# Square root with high precision +# Highly accurate square root implementation var root2 = sqrt(Decimal("2")) -print(root2) # 1.414213562373095048801688724... - -# Special cases -print(Decimal("10") ** 0) # 1 (anything to power 0) -print(Decimal("1") ** 20) # 1 (1 to any power) -print(Decimal("0") ** 5) # 0 (0 to positive power) -``` - -### 6. Type Conversions and Interoperability - -```mojo -var d = Decimal("123.456") +print(root2) # 1.4142135623730950488016887242... -# Converting to string (for display or serialization) -var str_val = String(d) -print(str_val) # "123.456" +# Square root of imperfect squares +var root_15_9999 = sqrt(Decimal("15.9999")) +print(root_15_9999) # 3.9999874999804686889646053305... -# Getting internal representation -print(repr(d)) # Shows internal state +# Integer powers with fast binary exponentiation +var cubed = Decimal("3") ** 3 +print(cubed) # 27 -# Converting to int (truncates toward zero) -var i = Int(d) -print(i) # 123 - -# Converting to float (may lose precision) -var f = Float64(d) -print(f) # 123.456 +# Negative powers (reciprocals) +var recip = Decimal("2") ** (-1) +print(recip) # 0.5 ``` -### 7. Handling Edge Cases and Errors +### 6. Robust Edge Case Handling ```mojo # Division by zero is properly caught @@ -202,33 +207,36 @@ try: except: print("Zero to negative power properly detected") -# Smallest representable positive value -var tiny = Decimal("0." + "0" * 27 + "1") # 28 decimal places -print(tiny) # 0.0000000000000000000000000001 +# Overflow detection and prevention +var max_val = Decimal.MAX() +try: + var overflow = max_val * Decimal("2") +except: + print("Overflow correctly detected") ``` -### 8. Equality and Zero Comparisons +### 7. Equality and Comparison Operations ```mojo -# Equal values with different representations +# Equal values with different scales var a = Decimal("123.4500") var b = Decimal("123.45") print(a == b) # True (numeric value equality) -# Zero values with different scales -var z1 = Decimal("0") -var z2 = Decimal("0.000") -print(z1 == z2) # True - -# Zero detection -print(z1.is_zero()) # True -print(z2.is_zero()) # True +# Comparison operators +var c = Decimal("100") +var d = Decimal("200") +print(c < d) # True +print(c <= d) # True +print(c > d) # False +print(c >= d) # False +print(c != d) # True ``` -### 9. Real World Financial Examples +### 8. Real World Financial Examples ```mojo -# Monthly loan payment calculation +# Monthly loan payment calculation with precise interest var principal = Decimal("200000") # $200,000 loan var annual_rate = Decimal("0.05") # 5% interest rate var monthly_rate = annual_rate / Decimal("12") @@ -239,39 +247,6 @@ var numerator = monthly_rate * (Decimal("1") + monthly_rate) ** 360 var denominator = (Decimal("1") + monthly_rate) ** 360 - Decimal("1") var payment = principal * (numerator / denominator) print("Monthly payment: $" + String(round(payment, 2))) # $1,073.64 - -# Correct handling of multiple items and discounts -var item1 = Decimal("29.99") -var item2 = Decimal("59.99") -var subtotal = item1 + item2 # 89.98 -var discount = subtotal * Decimal("0.15") # 15% off -var after_discount = subtotal - discount -var tax = after_discount * Decimal("0.08") # 8% tax -var final = after_discount + tax -print("Final price: $" + String(round(final, 2))) -``` - -### 10. Maximum Precision and Limit Testing - -```mojo -# Maximum value supported -var max_val = Decimal.MAX() -print(max_val) # 79228162514264337593543950335 - -# Minimum value supported -var min_val = Decimal.MIN() -print(min_val) # -79228162514264337593543950335 - -# Operations near limits -var near_max = Decimal("79228162514264337593543950334") # MAX() - 1 -var still_valid = near_max + Decimal("1") -print(still_valid == max_val) # True - -# Maximum precision for high-accuracy scientific calculations -var pi = Decimal("3.1415926535897932384626433832") -var radius = Decimal("2.5") -var area = pi * (radius ** 2) -print("Circle area: " + String(area)) # Precisely calculated area ``` ## Tests and benches @@ -292,7 +267,7 @@ If you find DeciMojo useful for your research, consider listing it in your citat title = {DeciMojo: A fixed-point decimal arithmetic library in Mojo}, url = {https://github.com/forFudan/DeciMojo}, version = {0.1.0}, - note = {Computer Software} + note = {Computer Software} } ``` @@ -302,7 +277,6 @@ Copyright 2025 Yuhao Zhu Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - diff --git a/benches/bench.mojo b/benches/bench.mojo index 23aea6e..07954ba 100644 --- a/benches/bench.mojo +++ b/benches/bench.mojo @@ -2,6 +2,7 @@ from bench_add import main as bench_add from bench_subtract import main as bench_subtract from bench_multiply import main as bench_multiply from bench_divide import main as bench_divide +from bench_sqrt import main as bench_sqrt fn main() raises: @@ -9,3 +10,4 @@ fn main() raises: bench_subtract() bench_multiply() bench_divide() + bench_sqrt() diff --git a/benches/bench_add.mojo b/benches/bench_add.mojo index 8db56b7..681614e 100644 --- a/benches/bench_add.mojo +++ b/benches/bench_add.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal addition operations. Compares performance against Python's decimal module. """ -from decimojo.prelude import * +from decimojo import dm, Decimal from python import Python, PythonObject from time import perf_counter_ns import time @@ -137,9 +137,7 @@ fn main() raises: "Python decimal precision: " + String(pydecimal.getcontext().prec), log_file, ) - log_print( - "Mojo decimal precision: " + String(Decimal.MAX_PRECISION), log_file - ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) # Define benchmark cases log_print( diff --git a/benches/bench_divide.mojo b/benches/bench_divide.mojo index 4eb06a2..3add59a 100644 --- a/benches/bench_divide.mojo +++ b/benches/bench_divide.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal division operations. Compares performance against Python's decimal module. """ -from decimojo.prelude import * +from decimojo import dm, Decimal from python import Python, PythonObject from time import perf_counter_ns import time @@ -137,9 +137,7 @@ fn main() raises: "Python decimal precision: " + String(pydecimal.getcontext().prec), log_file, ) - log_print( - "Mojo decimal precision: " + String(Decimal.MAX_PRECISION), log_file - ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) # Define benchmark cases log_print( diff --git a/benches/bench_multiply.mojo b/benches/bench_multiply.mojo index cf9f39e..09281ae 100644 --- a/benches/bench_multiply.mojo +++ b/benches/bench_multiply.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal multiplication operations. Compares performance against Python's decimal module. """ -from decimojo.prelude import * +from decimojo import dm, Decimal from python import Python, PythonObject from time import perf_counter_ns import time @@ -137,9 +137,7 @@ fn main() raises: "Python decimal precision: " + String(pydecimal.getcontext().prec), log_file, ) - log_print( - "Mojo decimal precision: " + String(Decimal.MAX_PRECISION), log_file - ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) # Define benchmark cases log_print( diff --git a/benches/bench_sqrt.mojo b/benches/bench_sqrt.mojo new file mode 100644 index 0000000..0387938 --- /dev/null +++ b/benches/bench_sqrt.mojo @@ -0,0 +1,379 @@ +""" +Comprehensive benchmarks for Decimal square root operations. +Compares performance against Python's decimal module with 20 diverse test cases. +""" + +from decimojo import dm, Decimal +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_sqrt_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark( + name: String, + d_mojo: Decimal, + d_py: PythonObject, + iterations: Int, + log_file: PythonObject, +) raises: + """ + Run a benchmark comparing Mojo Decimal sqrt with Python Decimal sqrt. + + Args: + name: Name of the benchmark case. + d_mojo: Mojo Decimal operand. + d_py: Python Decimal operand. + iterations: Number of iterations to run. + log_file: File object for logging results. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Decimal: " + String(d_mojo), log_file) + + # Verify correctness - import math module for Python's sqrt + var math = Python.import_module("math") + var mojo_result = decimojo.sqrt(d_mojo) + var py_result = math.sqrt(d_py) + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = decimojo.sqrt(d_mojo) + var mojo_time = (perf_counter_ns() - t0) / iterations + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = math.sqrt(d_py) + var python_time = (perf_counter_ns() - t0) / iterations + + # Print results with speedup comparison + log_print( + "Mojo Decimal: " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python Decimal: " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(python_time / mojo_time), log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Display benchmark header with system information + log_print("=== DeciMojo Square Root Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 100 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning square root benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Perfect square (small) + var case1_mojo = Decimal("16") + var case1_py = pydecimal.Decimal("16") + run_benchmark( + "Perfect square (small)", + case1_mojo, + case1_py, + iterations, + log_file, + ) + + # Case 2: Perfect square (large) + var case2_mojo = Decimal("1000000") # 1000^2 + var case2_py = pydecimal.Decimal("1000000") + run_benchmark( + "Perfect square (large)", + case2_mojo, + case2_py, + iterations, + log_file, + ) + + # Case 3: Non-perfect square (small irrational) + var case3_mojo = Decimal("2") # sqrt(2) is irrational + var case3_py = pydecimal.Decimal("2") + run_benchmark( + "Non-perfect square (small irrational)", + case3_mojo, + case3_py, + iterations, + log_file, + ) + + # Case 4: Non-perfect square (medium) + var case4_mojo = Decimal("123.456") + var case4_py = pydecimal.Decimal("123.456") + run_benchmark( + "Non-perfect square (medium)", + case4_mojo, + case4_py, + iterations, + log_file, + ) + + # Case 5: Very small number + var case5_mojo = Decimal("0.0000001") + var case5_py = pydecimal.Decimal("0.0000001") + run_benchmark( + "Very small number", + case5_mojo, + case5_py, + iterations, + log_file, + ) + + # Case 6: Very large number + var case6_mojo = Decimal("1" + "0" * 20) # 10^20 + var case6_py = pydecimal.Decimal("1" + "0" * 20) + run_benchmark( + "Very large number", + case6_mojo, + case6_py, + iterations, + log_file, + ) + + # Case 7: Number just above 1 + var case7_mojo = Decimal("1.0000001") + var case7_py = pydecimal.Decimal("1.0000001") + run_benchmark( + "Number just above 1", + case7_mojo, + case7_py, + iterations, + log_file, + ) + + # Case 8: Number just below 1 + var case8_mojo = Decimal("0.9999999") + var case8_py = pydecimal.Decimal("0.9999999") + run_benchmark( + "Number just below 1", + case8_mojo, + case8_py, + iterations, + log_file, + ) + + # Case 9: High precision value + var case9_mojo = Decimal("1.23456789012345678901234567") + var case9_py = pydecimal.Decimal("1.23456789012345678901234567") + run_benchmark( + "High precision value", + case9_mojo, + case9_py, + iterations, + log_file, + ) + + # Case 10: Number with exact square root in decimal + var case10_mojo = Decimal("0.04") # sqrt = 0.2 + var case10_py = pydecimal.Decimal("0.04") + run_benchmark( + "Number with exact square root", + case10_mojo, + case10_py, + iterations, + log_file, + ) + + # Case 11: Number close to a perfect square + var case11_mojo = Decimal("99.99") # Close to 10² + var case11_py = pydecimal.Decimal("99.99") + run_benchmark( + "Number close to a perfect square", + case11_mojo, + case11_py, + iterations, + log_file, + ) + + # Case 12: Even larger perfect square + var case12_mojo = Decimal("1000000000") # 31622.78...^2 + var case12_py = pydecimal.Decimal("1000000000") + run_benchmark( + "Very large perfect square", + case12_mojo, + case12_py, + iterations, + log_file, + ) + + # Case 13: Number with repeating pattern in result + var case13_mojo = Decimal("3") # sqrt(3) has repeating pattern + var case13_py = pydecimal.Decimal("3") + run_benchmark( + "Number with repeating pattern in result", + case13_mojo, + case13_py, + iterations, + log_file, + ) + + # Case 14: Number with trailing zeros (exact square) + var case14_mojo = Decimal("144.0000") + var case14_py = pydecimal.Decimal("144.0000") + run_benchmark( + "Number with trailing zeros", + case14_mojo, + case14_py, + iterations, + log_file, + ) + + # Case 15: Number slightly larger than a perfect square + var case15_mojo = Decimal("4.0001") + var case15_py = pydecimal.Decimal("4.0001") + run_benchmark( + "Slightly larger than perfect square", + case15_mojo, + case15_py, + iterations, + log_file, + ) + + # Case 16: Number slightly smaller than a perfect square + var case16_mojo = Decimal("15.9999") + var case16_py = pydecimal.Decimal("15.9999") + run_benchmark( + "Slightly smaller than perfect square", + case16_mojo, + case16_py, + iterations, + log_file, + ) + + # Case 17: Number with many decimal places + var case17_mojo = Decimal("0.12345678901234567890") + var case17_py = pydecimal.Decimal("0.12345678901234567890") + run_benchmark( + "Number with many decimal places", + case17_mojo, + case17_py, + iterations, + log_file, + ) + + # Case 18: Number close to maximum representable value + var case18_mojo = Decimal.MAX() - Decimal("1") + var case18_py = pydecimal.Decimal(String(case18_mojo)) + run_benchmark( + "Number close to maximum value", + case18_mojo, + case18_py, + iterations, + log_file, + ) + + # Case 19: Very tiny number (close to minimum positive value) + var case19_mojo = Decimal( + "0." + "0" * 27 + "1" + ) # Smallest positive decimal + var case19_py = pydecimal.Decimal(String(case19_mojo)) + run_benchmark( + "Very tiny positive number", + case19_mojo, + case19_py, + iterations, + log_file, + ) + + # Case 20: Number requiring many Newton-Raphson iterations + var case20_mojo = Decimal("987654321.123456789") + var case20_py = pydecimal.Decimal("987654321.123456789") + run_benchmark( + "Number requiring many iterations", + case20_mojo, + case20_py, + iterations, + log_file, + ) + + # Display summary + log_print("\n=== Square Root Benchmark Summary ===", log_file) + log_print("Benchmarked: 20 different square root cases", log_file) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print( + "Performance: See detailed results above for each case", log_file + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/benches/bench_subtract.mojo b/benches/bench_subtract.mojo index 9f0ed03..6046c78 100644 --- a/benches/bench_subtract.mojo +++ b/benches/bench_subtract.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal subtraction operations. Compares performance against Python's decimal module. """ -from decimojo.prelude import * +from decimojo import dm, Decimal from python import Python, PythonObject from time import perf_counter_ns import time @@ -137,9 +137,7 @@ fn main() raises: "Python decimal precision: " + String(pydecimal.getcontext().prec), log_file, ) - log_print( - "Mojo decimal precision: " + String(Decimal.MAX_PRECISION), log_file - ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) # Define benchmark cases log_print( diff --git a/mojoproject.toml b/mojoproject.toml index 991f829..b3b16c1 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -17,7 +17,7 @@ package = "magic run format && magic run mojo package src/decimojo && cp decimoj p = "clear && magic run package" # delete the package files in tests folder -delete_package = "rm tests/decimojo.mojopkg" +delete_package = "rm tests/decimojo.mojopkg && rm benches/decimojo.mojopkg" # debugs (run the testing files only) debug = "magic run package && magic run mojo tests/*.mojo && magic run delete_package" @@ -34,12 +34,11 @@ test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && ma test_round = "magic run package && magic run mojo test tests/test_rounding.mojo && magic run delete_package" # benches -bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .." +bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" b = "clear && magic run bench" - -# individual bench files -b_mul = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .." -b_div = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .." +bench_mul = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .. && magic run delete_package" +bench_div = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .. && magic run delete_package" +bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo && cd .. && magic run delete_package" # before commit final = "magic run test && magic run bench" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index ce2d14d..3d31ed4 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -5,11 +5,26 @@ # ===----------------------------------------------------------------------=== # """ -DeciMojo - Correctly-rounded, fixed-point Decimal library for Mojo. +DeciMojo: A fixed-point decimal arithmetic library in Mojo. + +You can import a list of useful objects in one line, e.g., + +```mojo +from decimojo import decimojo, dm, Decimal, D +``` + +where `decimojo` is the module itself, `dm` is an alias for the module, +`Decimal` is the `Decimal` type, and `D` is an alias for the `Decimal` type. """ +import decimojo +import decimojo as dm + from .decimal import Decimal +from .decimal import Decimal as D + from .rounding_mode import RoundingMode + from .maths import ( add, subtract, diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index 80e3331..85588d8 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -8,6 +8,18 @@ # which supports correctly-rounded, fixed-point arithmetic. # # ===----------------------------------------------------------------------=== # +# +# Organization of methods: +# - Constructors and life time methods +# - Output dunders, type-transfer dunders, and other type-transfer methods +# - Basic unary arithmetic operation dunders +# - Basic binary arithmetic operation dunders +# - Basic binary logic operation dunders +# - Other dunders that implements tratis +# - Other methods +# - Internal methods +# +# ===----------------------------------------------------------------------=== # # Docstring style: # 1. Description # 2. Parameters @@ -73,8 +85,7 @@ struct Decimal( """Scale information and the sign.""" # Constants - alias MAX_PRECISION = 28 - alias MAX_SCALE = 128 + alias MAX_SCALE: Int = 28 alias MAX_AS_UINT128 = UInt128(79228162514264337593543950335) alias MAX_AS_INT128 = Int128(79228162514264337593543950335) alias MAX_AS_UINT256 = UInt256(79228162514264337593543950335) @@ -166,7 +177,7 @@ struct Decimal( return Decimal(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, Decimal.SIGN_MASK) # ===------------------------------------------------------------------=== # - # Constructors and life time methods + # Constructors and life time dunder methods # ===------------------------------------------------------------------=== # fn __init__(out self): @@ -183,19 +194,19 @@ struct Decimal( low: UInt32, mid: UInt32, high: UInt32, - negative: Bool, scale: UInt32, + negative: Bool, ): """ - Initializes a Decimal with separate components. - the scale can be larger than 28, but will be scaled to the maximum precision. + Initializes a Decimal with five components. + If the scale is greater than MAX_SCALE, it is set to MAX_SCALE. Args: low: Least significant 32 bits of coefficient. mid: Middle 32 bits of coefficient. high: Most significant 32 bits of coefficient. - negative: True if the number is negative. scale: Number of decimal places (0-28). + negative: True if the number is negative. """ self.low = low self.mid = mid @@ -204,7 +215,7 @@ struct Decimal( # First set the flags without capping to initialize properly var flags: UInt32 = 0 - # Set the initial scale (may be higher than MAX_PRECISION) + # Set the initial scale (may be higher than MAX_SCALE) flags |= (scale << Self.SCALE_SHIFT) & Self.SCALE_MASK # Set the sign bit if negative @@ -213,14 +224,14 @@ struct Decimal( self.flags = flags - # Now check if we need to round due to exceeding MAX_PRECISION - if scale > Self.MAX_PRECISION: + # Now check if we need to round due to exceeding MAX_SCALE + if scale > Self.MAX_SCALE: # We need to properly round the value, not just change the scale - var scale_diff = scale - Self.MAX_PRECISION + var scale_diff = scale - Self.MAX_SCALE # The 'self' is already initialized above, so we can call _scale_down on it self = self._scale_down(Int(scale_diff), RoundingMode.HALF_EVEN()) - # No else needed as the value is already properly set if scale <= MAX_PRECISION + # No else needed as the value is already properly set if scale <= MAX_SCALE fn __init__( out self, low: UInt32, mid: UInt32, high: UInt32, flags: UInt32 @@ -234,7 +245,7 @@ struct Decimal( var scale = (flags & Self.SCALE_MASK) >> Self.SCALE_SHIFT # Use the previous constructor which handles scale rounding properly - self = Self(low, mid, high, is_negative, scale) + self = Self(low, mid, high, scale, is_negative) fn __init__(out self, integer: Int): """ @@ -302,25 +313,43 @@ struct Decimal( self.mid = UInt32((integer >> 32) & 0xFFFFFFFF) self.high = UInt32((integer >> 64) & 0xFFFFFFFF) - fn __init__(out self, integer: UInt128): + # TODO: Add arguments to specify the scale and sign + fn __init__(out self, integer: UInt64): """ - Initializes a Decimal from an UInt128 value. - ***WARNING***: This constructor can only handle values up to 96 bits. + Initializes a Decimal from an UInt64 value. + The `high` word will always be 0. """ self.low = UInt32(integer & 0xFFFFFFFF) self.mid = UInt32((integer >> 32) & 0xFFFFFFFF) - self.high = UInt32((integer >> 64) & 0xFFFFFFFF) + self.high = 0 self.flags = 0 - fn __init__(out self, integer: UInt256): + # TODO: Add arguments to specify the scale and sign for all integer constructors + fn __init__( + out self, integer: UInt128, scale: UInt32 = 0, negative: Bool = False + ): + """ + Initializes a Decimal from an UInt128 value. + ***WARNING***: This constructor can only handle values up to 96 bits. + """ + var low = UInt32(integer & 0xFFFFFFFF) + var mid = UInt32((integer >> 32) & 0xFFFFFFFF) + var high = UInt32((integer >> 64) & 0xFFFFFFFF) + + self = Decimal(low, mid, high, scale, negative) + + fn __init__( + out self, integer: UInt256, scale: UInt32 = 0, negative: Bool = False + ): """ Initializes a Decimal from an UInt256 value. ***WARNING***: This constructor can only handle values up to 96 bits. """ - self.low = UInt32(integer & 0xFFFFFFFF) - self.mid = UInt32((integer >> 32) & 0xFFFFFFFF) - self.high = UInt32((integer >> 64) & 0xFFFFFFFF) - self.flags = 0 + var low = UInt32(integer & 0xFFFFFFFF) + var mid = UInt32((integer >> 32) & 0xFFFFFFFF) + var high = UInt32((integer >> 64) & 0xFFFFFFFF) + + self = Decimal(low, mid, high, scale, negative) fn __init__(out self, s: String) raises: """ @@ -450,11 +479,11 @@ struct Decimal( # Move decimal point left (increase scale) scale += -exponent - # STEP 2: If scale > max_precision, + # STEP 2: If scale > MAX_SCALE, # round the coefficient string after truncating # and re-calculate the scale - if scale > Self.MAX_PRECISION: - var diff_scale = scale - Self.MAX_PRECISION + if scale > Self.MAX_SCALE: + var diff_scale = scale - Self.MAX_SCALE var kept_digits = len(string_of_coefficient) - diff_scale # Truncate the coefficient string to 29 digits @@ -473,7 +502,7 @@ struct Decimal( for i in range(len(string_of_coefficient)): result_chars.append(string_of_coefficient[i]) - var pos = Self.MAX_PRECISION + var pos = Self.MAX_SCALE while pos >= 0 and carry > 0: var digit = ord(result_chars[pos]) - ord(String("0")) digit += carry @@ -493,7 +522,7 @@ struct Decimal( for ch in result_chars: string_of_coefficient += ch[] - scale = Self.MAX_PRECISION + scale = Self.MAX_SCALE # STEP 2: Check for overflow # Check if the integral part of the coefficient is too large @@ -621,20 +650,18 @@ struct Decimal( self.flags |= Self.SIGN_MASK # TODO: Use generic floating-point type if possible. - fn __init__(out self, f: Float64, *, max_precision: Bool = True) raises: + fn __init__(out self, f: Float64, *, MAX_SCALE: Bool = True) raises: """ Initializes a Decimal from a floating-point value. You may lose precision because float representation is inexact. """ var float_str: String - if max_precision: + if MAX_SCALE: # Use maximum precision # Convert float to string ith high precision to capture all significant digits - # The format ensures we get up to MAX_PRECISION decimal places - float_str = decimojo.str._float_to_decimal_str( - f, Self.MAX_PRECISION - ) + # The format ensures we get up to MAX_SCALE decimal places + float_str = decimojo.str._float_to_decimal_str(f, Self.MAX_SCALE) else: # Use default string representation # Convert float to string with Mojo's default precision @@ -1060,6 +1087,24 @@ struct Decimal( return decimojo.round(self, 0, RoundingMode.HALF_EVEN()) + # ===------------------------------------------------------------------=== # + # Methematical methods that do not implement a trait (not a dunder) + # sqrt + # ===------------------------------------------------------------------=== # + + fn sqrt(self) raises -> Self: + """ + Calculates the square root of this Decimal. + + Returns: + The square root of this Decimal. + + Raises: + Error: If the operation would result in overflow. + """ + + return decimojo.sqrt(self) + # ===------------------------------------------------------------------=== # # Other methods # ===------------------------------------------------------------------=== # @@ -1133,18 +1178,6 @@ struct Decimal( """Returns True if this Decimal is NaN (Not a Number).""" return (self.flags & Self.NAN_MASK) != 0 - fn is_uint32able(self) -> Bool: - """ - Returns True if the coefficient can be represented as a UInt32 value. - """ - return self.high == 0 and self.mid == 0 - - fn is_uint64able(self) -> Bool: - """ - Returns True if the coefficient can be represented as a UInt64 value. - """ - return self.high == 0 - fn scale(self) -> Int: """Returns the scale (number of decimal places) of this Decimal.""" return Int((self.flags & Self.SCALE_MASK) >> Self.SCALE_SHIFT) @@ -1342,10 +1375,10 @@ struct Decimal( # Update the scale in the flags var new_scale = self.scale() + scale_diff - if new_scale > Self.MAX_PRECISION + 1: + if new_scale > Self.MAX_SCALE + 1: # Cannot scale beyond max precision, limit the scaling - scale_diff = Self.MAX_PRECISION + 1 - self.scale() - new_scale = Self.MAX_PRECISION + 1 + scale_diff = Self.MAX_SCALE + 1 - self.scale() + new_scale = Self.MAX_SCALE + 1 # With UInt128, we can represent the coefficient as a single value var coefficient = UInt128(self.high) << 64 | UInt128( diff --git a/src/decimojo/maths/__init__.mojo b/src/decimojo/maths/__init__.mojo index 5030eeb..db7309d 100644 --- a/src/decimojo/maths/__init__.mojo +++ b/src/decimojo/maths/__init__.mojo @@ -25,13 +25,12 @@ # lcm(a: Decimal, b: Decimal): Returns least common multiple of a and b # ===----------------------------------------------------------------------=== # -from .arithmetics import ( +from .basic import ( add, subtract, multiply, true_divide, - power, - absolute, - sqrt, ) +from .exp import power, sqrt from .rounding import round +from .misc import absolute diff --git a/src/decimojo/maths/arithmetics.mojo b/src/decimojo/maths/basic.mojo similarity index 81% rename from src/decimojo/maths/arithmetics.mojo rename to src/decimojo/maths/basic.mojo index 46514ac..dc37d39 100644 --- a/src/decimojo/maths/arithmetics.mojo +++ b/src/decimojo/maths/basic.mojo @@ -4,8 +4,7 @@ # https://github.com/forFudan/decimojo/blob/main/LICENSE # ===----------------------------------------------------------------------=== # # -# Implements basic object methods for the Decimal type -# which supports correctly-rounded, fixed-point arithmetic. +# Implements basic arithmetic functions for the Decimal type # # ===----------------------------------------------------------------------=== # # @@ -15,9 +14,6 @@ # subtract(x1: Decimal, x2: Decimal): Subtracts the x2 Decimal from x1 and returns a new Decimal # multiply(x1: Decimal, x2: Decimal): Multiplies two Decimal values and returns a new Decimal containing the product # true_divide(x1: Decimal, x2: Decimal): Divides x1 by x2 and returns a new Decimal containing the quotient -# power(base: Decimal, exponent: Decimal): Raises base to the power of exponent (integer exponents only) -# power(base: Decimal, exponent: Int): Convenience method for integer exponents -# sqrt(x: Decimal): Computes the square root of x using Newton-Raphson method # # ===----------------------------------------------------------------------=== # @@ -28,10 +24,6 @@ Implements functions for mathematical operations on Decimal objects. from decimojo.decimal import Decimal from decimojo.rounding_mode import RoundingMode -# ===----------------------------------------------------------------------=== # -# Binary arithmetic operations functions -# ===----------------------------------------------------------------------=== # - # TODO: Like `multiply` use combined bits to determine the appropriate method fn add(x1: Decimal, x2: Decimal) raises -> Decimal: @@ -56,8 +48,8 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: x2.low, x2.mid, x2.high, - x1.flags & x2.flags == Decimal.SIGN_MASK, max(x1.scale(), x2.scale()), + x1.flags & x2.flags == Decimal.SIGN_MASK, ) elif x2.is_zero(): @@ -65,8 +57,8 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: x1.low, x1.mid, x1.high, - x1.flags & x2.flags == Decimal.SIGN_MASK, max(x1.scale(), x2.scale()), + x1.flags & x2.flags == Decimal.SIGN_MASK, ) # Integer addition with scale of 0 (true integers) @@ -89,7 +81,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: var mid = UInt32((summation >> 32) & 0xFFFFFFFF) var high = UInt32((summation >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, x1.is_negative(), 0) + return Decimal(low, mid, high, 0, x1.is_negative()) # Different signs: subtract the smaller from the larger else: @@ -107,7 +99,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: mid = UInt32((diff >> 32) & 0xFFFFFFFF) high = UInt32((diff >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # Integer addition with positive scales elif x1.is_integer() and x2.is_integer(): @@ -137,7 +129,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: var mid = UInt32((summation >> 32) & 0xFFFFFFFF) var high = UInt32((summation >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, x1.is_negative(), scale) + return Decimal(low, mid, high, scale, x1.is_negative()) # Different signs: subtract the smaller from the larger else: @@ -166,7 +158,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: mid = UInt32((diff >> 32) & 0xFFFFFFFF) high = UInt32((diff >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, scale) + return Decimal(low, mid, high, scale, is_negative) # Float addition with the same scale elif x1.scale() == x2.scale(): @@ -204,7 +196,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: mid = UInt32((truncated_summation >> 32) & 0xFFFFFFFF) high = UInt32((truncated_summation >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, final_scale) + return Decimal(low, mid, high, final_scale, is_negative) # Float addition which with different scales else: @@ -257,7 +249,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: mid = UInt32((truncated_summation >> 32) & 0xFFFFFFFF) high = UInt32((truncated_summation >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, final_scale) + return Decimal(low, mid, high, final_scale, is_negative) fn subtract(x1: Decimal, x2: Decimal) raises -> Decimal: @@ -315,7 +307,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Return zero while preserving the scale if x1_coef == 0 or x2_coef == 0: var result = Decimal.ZERO() - var result_scale = min(combined_scale, Decimal.MAX_PRECISION) + var result_scale = min(combined_scale, Decimal.MAX_SCALE) result.flags = UInt32( (result_scale << Decimal.SCALE_SHIFT) & Decimal.SCALE_MASK ) @@ -325,17 +317,17 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: if x1_coef == 1 and x2_coef == 1: # If the combined scale exceeds the maximum precision, # return 0 with leading zeros after the decimal point and correct sign - if combined_scale > Decimal.MAX_PRECISION: + if combined_scale > Decimal.MAX_SCALE: return Decimal( 0, 0, 0, + Decimal.MAX_SCALE, is_negative, - Decimal.MAX_PRECISION, ) # Otherwise, return 1 with correct sign and scale - var final_scale = min(Decimal.MAX_PRECISION, combined_scale) - return Decimal(1, 0, 0, is_negative, final_scale) + var final_scale = min(Decimal.MAX_SCALE, combined_scale) + return Decimal(1, 0, 0, final_scale, is_negative) # SPECIAL CASE: First operand has coefficient of 1 if x1_coef == 1: @@ -351,12 +343,12 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Rounding may be needed. var num_digits_prod = decimojo.utility.number_of_digits(prod) var num_digits_to_keep = num_digits_prod - ( - combined_scale - Decimal.MAX_PRECISION + combined_scale - Decimal.MAX_SCALE ) var truncated_prod = decimojo.utility.truncate_to_digits( prod, num_digits_to_keep ) - var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + var final_scale = min(Decimal.MAX_SCALE, combined_scale) var low = UInt32(truncated_prod & 0xFFFFFFFF) var mid = UInt32((truncated_prod >> 32) & 0xFFFFFFFF) var high = UInt32((truncated_prod >> 64) & 0xFFFFFFFF) @@ -364,8 +356,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: low, mid, high, - is_negative, final_scale, + is_negative, ) # SPECIAL CASE: Second operand has coefficient of 1 @@ -382,12 +374,12 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Rounding may be needed. var num_digits_prod = decimojo.utility.number_of_digits(prod) var num_digits_to_keep = num_digits_prod - ( - combined_scale - Decimal.MAX_PRECISION + combined_scale - Decimal.MAX_SCALE ) var truncated_prod = decimojo.utility.truncate_to_digits( prod, num_digits_to_keep ) - var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + var final_scale = min(Decimal.MAX_SCALE, combined_scale) var low = UInt32(truncated_prod & 0xFFFFFFFF) var mid = UInt32((truncated_prod >> 32) & 0xFFFFFFFF) var high = UInt32((truncated_prod >> 64) & 0xFFFFFFFF) @@ -395,8 +387,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: low, mid, high, - is_negative, final_scale, + is_negative, ) # Determine the number of bits in the coefficients @@ -416,7 +408,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var prod: UInt64 = UInt64(x1.low) * UInt64(x2.low) var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) - return Decimal(low, mid, 0, is_negative, 0) + return Decimal(low, mid, 0, 0, is_negative) # Moderate integers, use UInt128 multiplication elif combined_num_bits <= 128: @@ -424,7 +416,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # Large integers, use UInt256 multiplication else: @@ -435,7 +427,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # SPECIAL CASE: Both operands are integers but with scales # Examples: 123.0 * 456.00 @@ -465,8 +457,8 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: low, mid, high, - is_negative, final_scale, + is_negative, ) # GENERAL CASES: Decimal multiplication with any scales @@ -480,24 +472,24 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var prod: UInt128 = x1_coef * x2_coef # Combined scale more than max precision, no need to truncate - if combined_scale <= Decimal.MAX_PRECISION: + if combined_scale <= Decimal.MAX_SCALE: var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, combined_scale) + return Decimal(low, mid, high, combined_scale, is_negative) # Combined scale no more than max precision, truncate with rounding else: var num_digits = decimojo.utility.number_of_digits(prod) var num_digits_to_keep = num_digits - ( - combined_scale - Decimal.MAX_PRECISION + combined_scale - Decimal.MAX_SCALE ) prod = decimojo.utility.truncate_to_digits(prod, num_digits_to_keep) - var final_scale = min(Decimal.MAX_PRECISION, combined_scale) + var final_scale = min(Decimal.MAX_SCALE, combined_scale) var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, final_scale) + return Decimal(low, mid, high, final_scale, is_negative) # SUB-CASE: Both operands are moderate # The bits of the product will not exceed 128 bits @@ -546,7 +538,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, final_scale) + return Decimal(low, mid, high, final_scale, is_negative) # REMAINING CASES: Both operands are big # The bits of the product will not exceed 192 bits @@ -596,7 +588,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, final_scale) + return Decimal(low, mid, high, final_scale, is_negative) fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: @@ -659,12 +651,12 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # SUB-CASE: divisor is 1 # If divisor is 1, return dividend with correct sign if x2_scale == 0: - return Decimal(x1.low, x1.mid, x1.high, is_negative, x1_scale) + return Decimal(x1.low, x1.mid, x1.high, x1_scale, is_negative) # SUB-CASE: divisor is of coefficient 1 with positive scale # diff_scale > 0, then final scale is diff_scale elif diff_scale > 0: - return Decimal(x1.low, x1.mid, x1.high, is_negative, diff_scale) + return Decimal(x1.low, x1.mid, x1.high, diff_scale, is_negative) # diff_scale < 0, then times 10 ** (-diff_scale) else: @@ -684,7 +676,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # If the result should be stored in UInt256 else: @@ -696,7 +688,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # SPECIAL CASE: The coefficients are equal # 特例: 係數相等 @@ -710,7 +702,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # If the scales are positive, return 1 with the difference in scales # For example, 0.1234 / 1234 = 0.0001 if diff_scale >= 0: - return Decimal(1, 0, 0, is_negative, diff_scale) + return Decimal(1, 0, 0, diff_scale, is_negative) # SUB-CASE: The scales are negative # diff_scale < 0, then times 1e-diff_scale @@ -721,7 +713,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # SPECIAL CASE: Modulus of coefficients is zero (exact division) # 特例: 係數的餘數爲零 (可除盡) @@ -740,7 +732,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, diff_scale) + return Decimal(low, mid, high, diff_scale, is_negative) else: # If diff_scale < 0, return the quotient with scaling up @@ -757,7 +749,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # If the result should be stored in UInt256 else: @@ -768,7 +760,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, 0) + return Decimal(low, mid, high, 0, is_negative) # REMAINING CASES: Perform long division # 其他情況: 進行長除法 @@ -889,7 +881,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, scale_of_quot) + return Decimal(low, mid, high, scale_of_quot, is_negative) # Otherwise, we need to truncate the first 29 or 28 digits else: @@ -909,7 +901,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var mid = UInt32((truncated_quot >> 32) & 0xFFFFFFFF) var high = UInt32((truncated_quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, scale_of_truncated_quot) + return Decimal(low, mid, high, scale_of_truncated_quot, is_negative) # SUB-CASE: Use UInt256 to store the quotient # Also the FALLBACK approach for the remaining cases @@ -964,7 +956,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var low = UInt32(quot256 & 0xFFFFFFFF) var mid = UInt32((quot256 >> 32) & 0xFFFFFFFF) var high = UInt32((quot256 >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, scale_of_quot) + return Decimal(low, mid, high, scale_of_quot, is_negative) # Otherwise, we need to truncate the first 29 or 28 digits else: @@ -989,215 +981,11 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: ) scale_of_truncated_quot -= 1 - print("DEBUG: truncated_quot", truncated_quot) - print("DEBUG: scale_of_truncated_quot", scale_of_truncated_quot) + # print("DEBUG: truncated_quot", truncated_quot) + # print("DEBUG: scale_of_truncated_quot", scale_of_truncated_quot) var low = UInt32(truncated_quot & 0xFFFFFFFF) var mid = UInt32((truncated_quot >> 32) & 0xFFFFFFFF) var high = UInt32((truncated_quot >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, is_negative, scale_of_truncated_quot) - - -fn power(base: Decimal, exponent: Decimal) raises -> Decimal: - """ - Raises base to the power of exponent and returns a new Decimal. - - Currently supports integer exponents only. - - Args: - base: The base value. - exponent: The power to raise base to. - It must be an integer or effectively an integer (e.g., 2.0). - - Returns: - A new Decimal containing the result of base^exponent - - Raises: - Error: If exponent is not an integer or if the operation would overflow. - Error: If zero is raised to a negative power. - """ - # Check if exponent is an integer - if not exponent.is_integer(): - raise Error("Power operation is only supported for integer exponents") - - # Convert exponent to integer - var exp_value = Int(exponent) - - # Special cases - if exp_value == 0: - # x^0 = 1 (including 0^0 = 1 by convention) - return Decimal.ONE() - - if exp_value == 1: - # x^1 = x - return base - - if base.is_zero(): - # 0^n = 0 for n > 0 - if exp_value > 0: - return Decimal.ZERO() - else: - # 0^n is undefined for n < 0 - raise Error("Zero cannot be raised to a negative power") - - if base.coefficient() == 1 and base.scale() == 0: - # 1^n = 1 for any n - return Decimal.ONE() - - # Handle negative exponents: x^(-n) = 1/(x^n) - var negative_exponent = exp_value < 0 - if negative_exponent: - exp_value = -exp_value - - # Binary exponentiation for efficiency - var result = Decimal.ONE() - var current_base = base - - while exp_value > 0: - if exp_value & 1: # exp_value is odd - result = result * current_base - - exp_value >>= 1 # exp_value = exp_value / 2 - - if exp_value > 0: - current_base = current_base * current_base - - # For negative exponents, take the reciprocal - if negative_exponent: - # For 1/x, use division - result = Decimal.ONE() / result - - return result - - -fn power(base: Decimal, exponent: Int) raises -> Decimal: - """ - Convenience method to raise base to an integer power. - - Args: - base: The base value. - exponent: The integer power to raise base to. - - Returns: - A new Decimal containing the result. - """ - return power(base, Decimal(exponent)) - - -# ===----------------------------------------------------------------------=== # -# Unary arithmetic operations functions -# ===----------------------------------------------------------------------=== # - - -fn absolute(x: Decimal) raises -> Decimal: - """ - Returns the absolute value of a Decimal number. - - Args: - x: The Decimal value to compute the absolute value of. - - Returns: - A new Decimal containing the absolute value of x. - """ - if x.is_negative(): - return -x - return x - - -fn sqrt(x: Decimal) raises -> Decimal: - """ - Computes the square root of a Decimal value using Newton-Raphson method. - - Args: - x: The Decimal value to compute the square root of. - - Returns: - A new Decimal containing the square root of x. - - Raises: - Error: If x is negative. - """ - # Special cases - if x.is_negative(): - raise Error("Cannot compute square root of negative number") - - if x.is_zero(): - return Decimal.ZERO() - - if x == Decimal.ONE(): - return Decimal.ONE() - - # Initial guess - a good guess helps converge faster - # For numbers near 1, use the number itself - # For very small or large numbers, scale appropriately - var guess: Decimal - var exponent = len(x.coefficient()) - x.scale() - - if exponent >= 0 and exponent <= 3: - # For numbers between 0.1 and 1000, start with x/2 + 0.5 - try: - var half_x = x / Decimal("2") - guess = half_x + Decimal("0.5") - except e: - raise e - else: - # For larger/smaller numbers, make a smarter guess - # This scales based on the magnitude of the number - var shift: Int - if exponent % 2 != 0: - # For odd exponents, adjust - shift = (exponent + 1) // 2 - else: - shift = exponent // 2 - - try: - # Use an approximation based on the exponent - if exponent > 0: - guess = Decimal("10") ** shift - else: - guess = Decimal("0.1") ** (-shift) - - except e: - raise e - - # Newton-Raphson iterations - # x_n+1 = (x_n + S/x_n) / 2 - var prev_guess = Decimal.ZERO() - var iteration_count = 0 - var max_iterations = 100 # Prevent infinite loops - - while guess != prev_guess and iteration_count < max_iterations: - # print("------------------------------------------------------") - # print("DEBUG: iteration_count", iteration_count) - # print("DEBUG: prev_guess", prev_guess) - # print("DEBUG: guess", guess) - - prev_guess = guess - - try: - var division_result = x / guess - var sum_result = guess + division_result - guess = sum_result / Decimal("2") - except e: - raise e - - iteration_count += 1 - - # If exact square root found - # Remove trailing zeros after the decimal point - var guess_coef = guess.coefficient() - var count = 0 - for _ in range(guess.scale()): - if guess_coef % 10 == 0: - guess_coef //= 10 - count += 1 - else: - break - if guess_coef * guess_coef == x.coefficient(): - var low = UInt32(guess_coef & 0xFFFFFFFF) - var mid = UInt32((guess_coef >> 32) & 0xFFFFFFFF) - var high = UInt32((guess_coef >> 64) & 0xFFFFFFFF) - return Decimal(low, mid, high, False, guess.scale() - count) - - return guess + return Decimal(low, mid, high, scale_of_truncated_quot, is_negative) diff --git a/src/decimojo/maths/exp.mojo b/src/decimojo/maths/exp.mojo index 2b25fe1..7465916 100644 --- a/src/decimojo/maths/exp.mojo +++ b/src/decimojo/maths/exp.mojo @@ -4,6 +4,227 @@ # https://github.com/forFudan/decimojo/blob/main/LICENSE # ===----------------------------------------------------------------------=== # # -# Implements exponential and logarithmic functions for the Decimal type +# Implements exponential functions for the Decimal type # # ===----------------------------------------------------------------------=== # +# +# List of functions in this module: +# +# power(base: Decimal, exponent: Decimal): Raises base to the power of exponent (integer exponents only) +# power(base: Decimal, exponent: Int): Convenience method for integer exponents +# sqrt(x: Decimal): Computes the square root of x using Newton-Raphson method +# +# ===----------------------------------------------------------------------=== # + +import math as builtin_math +import testing + + +fn power(base: Decimal, exponent: Decimal) raises -> Decimal: + """ + Raises base to the power of exponent and returns a new Decimal. + + Currently supports integer exponents only. + + Args: + base: The base value. + exponent: The power to raise base to. + It must be an integer or effectively an integer (e.g., 2.0). + + Returns: + A new Decimal containing the result of base^exponent + + Raises: + Error: If exponent is not an integer or if the operation would overflow. + Error: If zero is raised to a negative power. + """ + # Check if exponent is an integer + if not exponent.is_integer(): + raise Error("Power operation is only supported for integer exponents") + + # Convert exponent to integer + var exp_value = Int(exponent) + + # Special cases + if exp_value == 0: + # x^0 = 1 (including 0^0 = 1 by convention) + return Decimal.ONE() + + if exp_value == 1: + # x^1 = x + return base + + if base.is_zero(): + # 0^n = 0 for n > 0 + if exp_value > 0: + return Decimal.ZERO() + else: + # 0^n is undefined for n < 0 + raise Error("Zero cannot be raised to a negative power") + + if base.coefficient() == 1 and base.scale() == 0: + # 1^n = 1 for any n + return Decimal.ONE() + + # Handle negative exponents: x^(-n) = 1/(x^n) + var negative_exponent = exp_value < 0 + if negative_exponent: + exp_value = -exp_value + + # Binary exponentiation for efficiency + var result = Decimal.ONE() + var current_base = base + + while exp_value > 0: + if exp_value & 1: # exp_value is odd + result = result * current_base + + exp_value >>= 1 # exp_value = exp_value / 2 + + if exp_value > 0: + current_base = current_base * current_base + + # For negative exponents, take the reciprocal + if negative_exponent: + # For 1/x, use division + result = Decimal.ONE() / result + + return result + + +fn power(base: Decimal, exponent: Int) raises -> Decimal: + """ + Convenience method to raise base to an integer power. + + Args: + base: The base value. + exponent: The integer power to raise base to. + + Returns: + A new Decimal containing the result. + """ + return power(base, Decimal(exponent)) + + +fn sqrt(x: Decimal) raises -> Decimal: + """ + Computes the square root of a Decimal value using Newton-Raphson method. + + Args: + x: The Decimal value to compute the square root of. + + Returns: + A new Decimal containing the square root of x. + + Raises: + Error: If x is negative. + """ + # Special cases + if x.is_negative(): + raise Error( + "Error in sqrt: Cannot compute square root of a negative number" + ) + + if x.is_zero(): + return Decimal.ZERO() + + if x == Decimal.ONE(): + return Decimal.ONE() + + var x_coef: UInt128 = x.coefficient() + var x_scale = x.scale() + + # Initial guess - a good guess helps converge faster + # use floating point approach to quickly find a good guess + + var guess: Decimal + + # For numbers with zero scale (true integers) + if x_scale == 0: + var float_sqrt = builtin_math.sqrt(Float64(x_coef)) + guess = Decimal(UInt128(float_sqrt)) + # print("DEBUG: scale = 0") + + # For numbers with even scale + elif x_scale % 2 == 0: + var float_sqrt = builtin_math.sqrt(Float64(x_coef)) + guess = Decimal(UInt128(float_sqrt), negative=False, scale=x_scale >> 1) + # print("DEBUG: scale is even") + + # For numbers with odd scale + else: + var float_sqrt = builtin_math.sqrt(Float64(x_coef)) * Float64(3.15625) + guess = Decimal( + UInt128(float_sqrt), negative=False, scale=(x_scale + 1) >> 1 + ) + # print("DEBUG: scale is odd") + + # print("DEBUG: initial guess", guess) + + # print("DEBUG: initial guess", guess) + testing.assert_false(guess.is_zero(), "Initial guess should not be zero") + + # Newton-Raphson iterations + # x_n+1 = (x_n + S/x_n) / 2 + var prev_guess = Decimal.ZERO() + var iteration_count = 0 + + # Iterate until guess converges or max iterations reached + # max iterations is set to 100 to avoid infinite loop + # log2(1e18) ~= 60, so 100 iterations should be enough + while guess != prev_guess and iteration_count < 100: + prev_guess = guess + var division_result = x / guess + var sum_result = guess + division_result + guess = sum_result / Decimal(2, 0, 0, 0, False) + iteration_count += 1 + + # print("------------------------------------------------------") + # print("DEBUG: iteration_count", iteration_count) + # print("DEBUG: prev guess", prev_guess) + # print("DEBUG: new guess ", guess) + + # print("DEBUG: iteration_count", iteration_count) + + # If exact square root found remove trailing zeros after the decimal point + # For example, sqrt(81) = 9, not 9.000000 + # For example, sqrt(100.0000) = 10.00 not 10.000000 + # Exact square means that the coefficient of guess after removing trailing zeros + # is equal to the coefficient of x + + var guess_coef = guess.coefficient() + + # No need to do this if the last digit of the coefficient of guess is not zero + if guess_coef % 10 == 0: + var num_digits_x_ceof = decimojo.utility.number_of_digits(x_coef) + var num_digits_x_sqrt_coef = (num_digits_x_ceof + 1) >> 1 + var num_digits_guess_coef = decimojo.utility.number_of_digits( + guess_coef + ) + var num_digits_to_decrease = num_digits_guess_coef - num_digits_x_sqrt_coef + + testing.assert_true( + num_digits_to_decrease >= 0, + "sqrt of x has fewer digits than expected", + ) + for _ in range(num_digits_to_decrease): + if guess_coef % 10 == 0: + guess_coef //= 10 + else: + break + else: + # print("DEBUG: guess", guess) + # print("DEBUG: guess_coef after removing trailing zeros", guess_coef) + if guess_coef * guess_coef == x_coef: + var low = UInt32(guess_coef & 0xFFFFFFFF) + var mid = UInt32((guess_coef >> 32) & 0xFFFFFFFF) + var high = UInt32((guess_coef >> 64) & 0xFFFFFFFF) + return Decimal( + low, + mid, + high, + guess.scale() - num_digits_to_decrease, + False, + ) + + return guess diff --git a/src/decimojo/maths/misc.mojo b/src/decimojo/maths/misc.mojo new file mode 100644 index 0000000..5a6cea9 --- /dev/null +++ b/src/decimojo/maths/misc.mojo @@ -0,0 +1,24 @@ +# ===----------------------------------------------------------------------=== # +# Distributed under the Apache 2.0 License with LLVM Exceptions. +# See LICENSE and the LLVM License for more information. +# https://github.com/forFudan/decimojo/blob/main/LICENSE +# ===----------------------------------------------------------------------=== # +# +# Implements miscellaneous mathematical functions for the Decimal type +# +# ===----------------------------------------------------------------------=== # + + +fn absolute(x: Decimal) raises -> Decimal: + """ + Returns the absolute value of a Decimal number. + + Args: + x: The Decimal value to compute the absolute value of. + + Returns: + A new Decimal containing the absolute value of x. + """ + if x.is_negative(): + return -x + return x diff --git a/src/decimojo/prelude.mojo b/src/decimojo/prelude.mojo deleted file mode 100644 index f377ebd..0000000 --- a/src/decimojo/prelude.mojo +++ /dev/null @@ -1,18 +0,0 @@ -""" -prelude -======= - -tries to find out a balance by providing a list of things that can be imported -at one time. The list contains the functions or types that are the most -essential for a user. - -You can use the following code to import them: - -```mojo -from decimojo.prelude import * -``` -""" - -import decimojo as dm -from decimojo import Decimal -from decimojo import RoundingMode diff --git a/src/decimojo/utility.mojo b/src/decimojo/utility.mojo index 3ed8431..82f4cfa 100644 --- a/src/decimojo/utility.mojo +++ b/src/decimojo/utility.mojo @@ -14,10 +14,12 @@ from memory import UnsafePointer from decimojo.decimal import Decimal +# UNSAFE fn bitcast[dtype: DType](dec: Decimal) -> Scalar[dtype]: """ Direct memory bit copy from Decimal (low, mid, high) to Mojo's Scalar type. This performs a bitcast/reinterpretation rather than bit manipulation. + ***UNSAFE***: This function is unsafe and should be used with caution. Parameters: dtype: The Mojo scalar type to bitcast to. @@ -26,7 +28,7 @@ fn bitcast[dtype: DType](dec: Decimal) -> Scalar[dtype]: dec: The Decimal to bitcast. Constraints: - `dtype` must be either `DType.uint128` or `DType.uint256`. + `dtype` must be `DType.uint128`. Returns: The bitcasted Decimal (low, mid, high) as a Mojo scalar. @@ -35,8 +37,8 @@ fn bitcast[dtype: DType](dec: Decimal) -> Scalar[dtype]: # Compile-time checker: ensure the dtype is either uint128 or uint256 constrained[ - dtype == DType.uint128 or dtype == DType.uint256, - "must be uint128 or uint256", + dtype == DType.uint128, + "must be uint128", ]() # Bitcast the Decimal to the desired Mojo scalar type @@ -199,7 +201,7 @@ fn truncate_to_digits[ When you want to apply a scale of 31 to the coefficient `997`, it will be `0.0000000000000000000000000000997` with 31 digits. However, we can only - store 28 digits in the coefficient (Decimal.MAX_PRECISION = 28). + store 28 digits in the coefficient (Decimal.MAX_SCALE = 28). Therefore, we need to truncate the coefficient to 0 (`3 - (31 - 28)`) digits and round it to the nearest even number. The truncated ceofficient will be `1`. @@ -209,7 +211,7 @@ fn truncate_to_digits[ When you want to apply a scale of 29 to the coefficient `234567`, it will be `0.00000000000000000000000234567` with 29 digits. However, we can only - store 28 digits in the coefficient (Decimal.MAX_PRECISION = 28). + store 28 digits in the coefficient (Decimal.MAX_SCALE = 28). Therefore, we need to truncate the coefficient to 5 (`6 - (29 - 28)`) digits and round it to the nearest even number. The truncated ceofficient will be `23457`. diff --git a/tests/test_arithmetics.mojo b/tests/test_arithmetics.mojo index 76db1f5..61ac158 100644 --- a/tests/test_arithmetics.mojo +++ b/tests/test_arithmetics.mojo @@ -2,7 +2,7 @@ Test Decimal arithmetic operations including addition, subtraction, and negation. """ -from decimojo.prelude import * +from decimojo import dm, Decimal import testing @@ -831,7 +831,7 @@ fn test_extreme_cases() raises: try: var a2 = Decimal("79228162514264337593543950335") # MAX() var b2 = Decimal("1") - var result2 = a2 + b2 + var _result2 = a2 + b2 print("WARNING: Addition beyond MAX() didn't raise an error") except: print("Addition overflow correctly detected") diff --git a/tests/test_conversions.mojo b/tests/test_conversions.mojo index 1744e95..2b95a21 100644 --- a/tests/test_conversions.mojo +++ b/tests/test_conversions.mojo @@ -3,7 +3,7 @@ Test Decimal conversion methods: __int__, __float__, and __str__ for different numerical cases. """ -from decimojo.prelude import * +from decimojo import dm, Decimal import testing import time diff --git a/tests/test_creation.mojo b/tests/test_creation.mojo index 30380b2..2bcd15c 100644 --- a/tests/test_creation.mojo +++ b/tests/test_creation.mojo @@ -287,7 +287,7 @@ fn test_decimal_from_string() raises: print("Decimal point without digits not supported") var only_zeros = Decimal("0.0000") - var max_precision = Decimal("0." + "9" * 28) + var MAX_SCALE = Decimal("0." + "9" * 28) testing.assert_equal( String(only_zeros), @@ -295,7 +295,7 @@ fn test_decimal_from_string() raises: "String of zeros should be represented as '0'", ) testing.assert_equal( - String(max_precision), + String(MAX_SCALE), "0." + "9" * 28, "Max precision should be preserved", ) @@ -329,38 +329,38 @@ fn test_decimal_from_components() raises: print("Testing Decimal Creation from Components") # Test case 1: Zero with zero scale - var zero = Decimal(0, 0, 0, False, 0) + var zero = Decimal(0, 0, 0, 0, False) testing.assert_equal(String(zero), "0", "Zero with scale 0") # Test case 2: One with zero scale - var one = Decimal(1, 0, 0, False, 0) + var one = Decimal(1, 0, 0, 0, False) testing.assert_equal(String(one), "1", "One with scale 0") # Test case 3: Negative one - var neg_one = Decimal(1, 0, 0, True, 0) + var neg_one = Decimal(1, 0, 0, 0, True) testing.assert_equal(String(neg_one), "-1", "Negative one") # Test case 4: Simple number with scale - var with_scale = Decimal(12345, 0, 0, False, 2) + var with_scale = Decimal(12345, 0, 0, 2, False) testing.assert_equal( String(with_scale), "123.45", "Simple number with scale 2" ) # Test case 5: Negative number with scale - var neg_with_scale = Decimal(12345, 0, 0, True, 2) + var neg_with_scale = Decimal(12345, 0, 0, 2, True) testing.assert_equal( String(neg_with_scale), "-123.45", "Negative number with scale 2" ) # Test case 6: Larger number using mid - var large = Decimal(0xFFFFFFFF, 5, 0, False, 0) + var large = Decimal(0xFFFFFFFF, 5, 0, 0, False) var expected_large = Decimal(String(0xFFFFFFFF + 5 * 4294967296)) testing.assert_equal( String(large), String(expected_large), "Large number using mid field" ) # Test case 7: Verify scale is correctly stored - var high_scale = Decimal(123, 0, 0, False, 10) + var high_scale = Decimal(123, 0, 0, 10, False) testing.assert_equal( high_scale.scale(), 10, "Scale should be correctly stored" ) @@ -369,7 +369,7 @@ fn test_decimal_from_components() raises: ) # Test case 8: Test large scale with negative number - var neg_high_scale = Decimal(123, 0, 0, True, 10) + var neg_high_scale = Decimal(123, 0, 0, 10, True) testing.assert_equal( String(neg_high_scale), "-0.0000000123", @@ -386,7 +386,7 @@ fn test_decimal_from_components() raises: ) # Test case 10: With high component - var with_high = Decimal(0, 0, 3, False, 0) + var with_high = Decimal(0, 0, 3, 0, False) testing.assert_equal( String(with_high), "55340232221128654848", @@ -394,11 +394,11 @@ fn test_decimal_from_components() raises: ) # Test case 11: Maximum possible scale - var max_scale = Decimal(123, 0, 0, False, 28) + var max_scale = Decimal(123, 0, 0, 28, False) testing.assert_equal(max_scale.scale(), 28, "Maximum scale should be 28") # Test case 12: Overflow scale protection - var overflow_scale = Decimal(123, 0, 0, False, 100) + var overflow_scale = Decimal(123, 0, 0, 100, False) testing.assert_true( overflow_scale.scale() <= 28, "Scale should be capped to max precision" ) diff --git a/tests/test_division.mojo b/tests/test_division.mojo index 7309396..a66e627 100644 --- a/tests/test_division.mojo +++ b/tests/test_division.mojo @@ -3,7 +3,7 @@ Comprehensive test suite for Decimal division operations. Includes 100 test cases covering edge cases, precision limits, and various scenarios. """ -from decimojo.prelude import * +from decimojo import dm, Decimal import testing @@ -191,10 +191,8 @@ fn test_precision_rounding() raises: # 25. Precision limit with repeating 9s var a25 = Decimal("1") / Decimal("81") # ~0.01234... - var precision_reached = a25.scale() <= Decimal.MAX_PRECISION - testing.assert_true( - precision_reached, "Scale should not exceed MAX_PRECISION" - ) + var precision_reached = a25.scale() <= Decimal.MAX_SCALE + testing.assert_true(precision_reached, "Scale should not exceed MAX_SCALE") # 26. Test precision with negative numbers var a26 = Decimal("-1") / Decimal("3") @@ -222,15 +220,15 @@ fn test_precision_rounding() raises: # 29. Division where quotient has more digits than precision allows var a29 = Decimal("12345678901234567890123456789") / Decimal("7") testing.assert_true( - a29.scale() <= Decimal.MAX_PRECISION, - "Scale should not exceed MAX_PRECISION", + a29.scale() <= Decimal.MAX_SCALE, + "Scale should not exceed MAX_SCALE", ) # 30. Division where both operands have maximum precision var a30 = Decimal("0." + "1" * 28) / Decimal("0." + "9" * 28) testing.assert_true( - a30.scale() <= Decimal.MAX_PRECISION, - "Scale should not exceed MAX_PRECISION", + a30.scale() <= Decimal.MAX_SCALE, + "Scale should not exceed MAX_SCALE", ) print("✓ Precision and rounding tests passed!") @@ -338,7 +336,7 @@ fn test_edge_cases() raises: var max_decimal = Decimal.MAX() var small_divisor = Decimal("0.0001") try: - var a43 = max_decimal / small_divisor + var _a43 = max_decimal / small_divisor except: print( "Division of very large number by very small number raised" @@ -350,7 +348,7 @@ fn test_edge_cases() raises: "0." + "0" * 27 + "1" ) # Smallest positive decimal var a44 = min_positive / Decimal("2") - testing.assert_true(a44.scale() <= Decimal.MAX_PRECISION) + testing.assert_true(a44.scale() <= Decimal.MAX_SCALE) # 45. Division by power of 2 (binary divisions) testing.assert_equal( @@ -364,11 +362,11 @@ fn test_edge_cases() raises: "Division by 9's", ) - # 47. Division resulting in exactly MAX_PRECISION digits + # 47. Division resulting in exactly MAX_SCALE digits var a47 = Decimal("1") / Decimal("3") testing.assert_true( - a47.scale() == Decimal.MAX_PRECISION, - "Case 47: Division resulting in exactly MAX_PRECISION digits failed", + a47.scale() == Decimal.MAX_SCALE, + "Case 47: Division resulting in exactly MAX_SCALE digits failed", ) # 48. Division of large integers resulting in max precision @@ -386,7 +384,7 @@ fn test_edge_cases() raises: # 50. Division with value at maximum supported scale var a50 = Decimal("0." + "0" * 27 + "5") / Decimal("1") testing.assert_true( - a50.scale() <= Decimal.MAX_PRECISION, + a50.scale() <= Decimal.MAX_SCALE, "Case 50: Division with value at maximum supported scale failed", ) @@ -643,9 +641,9 @@ fn test_rounding_behavior() raises: # 81. Banker's rounding at boundary (round to even) var a81 = Decimal("1") / Decimal( - String("3" + "0" * (Decimal.MAX_PRECISION - 1)) + String("3" + "0" * (Decimal.MAX_SCALE - 1)) ) - var expected = "0." + "0" * (Decimal.MAX_PRECISION - 1) + "3" + var expected = "0." + "0" * (Decimal.MAX_SCALE - 1) + "3" testing.assert_equal( String(a81), expected, "Case 81: Banker's rounding at boundary failed" ) @@ -667,8 +665,8 @@ fn test_rounding_behavior() raises: ) # 84. Division that results in exactly half a unit in last place - var a84 = Decimal("1") / Decimal("4" + "0" * Decimal.MAX_PRECISION) - var expected84 = Decimal("0." + "0" * (Decimal.MAX_PRECISION)) + var a84 = Decimal("1") / Decimal("4" + "0" * Decimal.MAX_SCALE) + var expected84 = Decimal("0." + "0" * (Decimal.MAX_SCALE)) testing.assert_equal( a84, expected84, @@ -717,13 +715,13 @@ fn test_rounding_behavior() raises: "Testing half-even rounding with odd digit before", ) - # 90. Division with MAX_PRECISION-3 digits + # 90. Division with MAX_SCALE-3 digits # 1 / 300000000000000000000000000 (26 zeros) var a90 = Decimal("1") / Decimal(String("300000000000000000000000000")) testing.assert_equal( String(a90), "0.0000000000000000000000000033", - "Case 90: Division with exactly MAX_PRECISION digits failed", + "Case 90: Division with exactly MAX_SCALE digits failed", ) print("✓ Rounding behavior tests passed!") @@ -735,7 +733,7 @@ fn test_error_cases() raises: # 91. Division by zero try: - var result = Decimal("123") / Decimal("0") + var _result = Decimal("123") / Decimal("0") testing.assert_true( False, "Case 91: Expected division by zero to raise exception" ) @@ -810,7 +808,7 @@ fn test_error_cases() raises: # 100. Division at the exact boundary of precision limit # 1 / 70000000000000000000000000000 (28 zeros) - var a100 = Decimal("1") / Decimal(String("7" + "0" * Decimal.MAX_PRECISION)) + var a100 = Decimal("1") / Decimal(String("7" + "0" * Decimal.MAX_SCALE)) testing.assert_equal( String(a100), "0.0000000000000000000000000000", diff --git a/tests/test_rounding.mojo b/tests/test_rounding.mojo index a262c93..b7e2faf 100644 --- a/tests/test_rounding.mojo +++ b/tests/test_rounding.mojo @@ -1,7 +1,7 @@ """ Test Decimal rounding methods with different rounding modes and precision levels. """ -from decimojo.prelude import * +from decimojo import dm, Decimal, RoundingMode import testing @@ -140,9 +140,9 @@ fn test_edge_cases() raises: ) # Test case 5: Rounding to maximum precision - var max_precision = Decimal("0." + "1" * 28) # 0.1111...1 (28 digits) + var MAX_SCALE = Decimal("0." + "1" * 28) # 0.1111...1 (28 digits) testing.assert_equal( - String(round(max_precision, 14)), + String(round(MAX_SCALE, 14)), "0.11111111111111", "Rounding from maximum precision", ) diff --git a/tests/test_sqrt.mojo b/tests/test_sqrt.mojo index 672a037..f5404b7 100644 --- a/tests/test_sqrt.mojo +++ b/tests/test_sqrt.mojo @@ -1,7 +1,7 @@ """ Comprehensive tests for the sqrt function of the Decimal type. """ -from decimojo.prelude import * +from decimojo import dm, Decimal from decimojo import sqrt import testing diff --git a/tests/test_utility.mojo b/tests/test_utility.mojo index 5a4a652..b9eaf13 100644 --- a/tests/test_utility.mojo +++ b/tests/test_utility.mojo @@ -5,7 +5,7 @@ Tests for the utility functions in the decimojo.utility module. from testing import assert_equal, assert_true import max -from decimojo.prelude import * +from decimojo import dm, Decimal from decimojo.utility import truncate_to_max, number_of_digits