diff --git a/README.md b/README.md index 3bbe935..3014813 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,7 @@ For brevity, you can refer to it as "deci" (derived from the Latin root "decimus When you add `from decimojo import dm, Decimal` at the top of your script, this imports the `decimojo` module into your namespace with the shorter alias `dm` and directly imports the `Decimal` type. This is equivalent to: ```mojo -import decimojo as dm -from decimojo import Decimal +from decimojo.prelude import dm, Decimal, RoundingMode ``` ## Advantages @@ -146,7 +145,7 @@ print(precise) # 0.1234567890123456789012345678 # Truncation to specific number of digits var large_num = Decimal("123456.789") -print(truncate_to_digits(large_num, 4)) # 1235 (banker's rounded) +print(round_to_keep_first_n_digits(large_num, 4)) # 1235 (banker's rounded) ``` ### 4. Sign Handling and Absolute Value diff --git a/benches/bench_add.mojo b/benches/bench_add.mojo index 681614e..d584762 100644 --- a/benches/bench_add.mojo +++ b/benches/bench_add.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal addition operations. Compares performance against Python's decimal module. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from python import Python, PythonObject from time import perf_counter_ns import time diff --git a/benches/bench_divide.mojo b/benches/bench_divide.mojo index 3add59a..5b211fc 100644 --- a/benches/bench_divide.mojo +++ b/benches/bench_divide.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal division operations. Compares performance against Python's decimal module. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from python import Python, PythonObject from time import perf_counter_ns import time diff --git a/benches/bench_multiply.mojo b/benches/bench_multiply.mojo index 09281ae..7810339 100644 --- a/benches/bench_multiply.mojo +++ b/benches/bench_multiply.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal multiplication operations. Compares performance against Python's decimal module. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from python import Python, PythonObject from time import perf_counter_ns import time diff --git a/benches/bench_round.mojo b/benches/bench_round.mojo new file mode 100644 index 0000000..d394f81 --- /dev/null +++ b/benches/bench_round.mojo @@ -0,0 +1,606 @@ +""" +Comprehensive benchmarks for Decimal rounding operations. +Compares performance against Python's decimal module across 32 diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_round_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark( + name: String, + d_mojo: Decimal, + places: Int, + d_py: PythonObject, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal round with Python Decimal quantize. + + Args: + name: Name of the benchmark case. + d_mojo: Mojo Decimal operand. + places: Number of decimal places to round to. + d_py: Python Decimal operand. + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Decimal: " + String(d_mojo), log_file) + log_print("Round to: " + String(places) + " places", log_file) + + # Get Python decimal module for quantize operation + var py = Python.import_module("builtins") + + # Execute the operations once to verify correctness + var mojo_result = round(d_mojo, places) + var py_result = py.round(d_py, ndigits=places) + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = round(d_mojo, places) + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py.round(d_py, ndigits=places) + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo Decimal: " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python Decimal: " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Rounding Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 1000 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning rounding benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Standard rounding to 2 decimal places + var case1_mojo = Decimal("123.456789") + var case1_py = pydecimal.Decimal("123.456789") + run_benchmark( + "Standard rounding to 2 places", + case1_mojo, + 2, + case1_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Banker's rounding (round half to even) for .5 with even preceding digit + var case2_mojo = Decimal("10.125") + var case2_py = pydecimal.Decimal("10.125") + run_benchmark( + "Banker's rounding with even preceding digit", + case2_mojo, + 2, + case2_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Banker's rounding (round half to even) for .5 with odd preceding digit + var case3_mojo = Decimal("10.135") + var case3_py = pydecimal.Decimal("10.135") + run_benchmark( + "Banker's rounding with odd preceding digit", + case3_mojo, + 2, + case3_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Rounding with less than half (<0.5) + var case4_mojo = Decimal("10.124") + var case4_py = pydecimal.Decimal("10.124") + run_benchmark( + "Rounding with less than half", + case4_mojo, + 2, + case4_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Rounding with more than half (>0.5) + var case5_mojo = Decimal("10.126") + var case5_py = pydecimal.Decimal("10.126") + run_benchmark( + "Rounding with more than half", + case5_mojo, + 2, + case5_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 6: Rounding to 0 decimal places (whole number) + var case6_mojo = Decimal("123.456") + var case6_py = pydecimal.Decimal("123.456") + run_benchmark( + "Rounding to whole number", + case6_mojo, + 0, + case6_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 7: Rounding a negative number + var case7_mojo = Decimal("-123.456") + var case7_py = pydecimal.Decimal("-123.456") + run_benchmark( + "Rounding negative number", + case7_mojo, + 2, + case7_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 8: Rounding to negative places (tens) + var case8_mojo = Decimal("123.456") + var case8_py = pydecimal.Decimal("123.456") + run_benchmark( + "Rounding to tens (negative places)", + case8_mojo, + -1, + case8_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 9: Rounding to negative places (hundreds) + var case9_mojo = Decimal("1234.56") + var case9_py = pydecimal.Decimal("1234.56") + run_benchmark( + "Rounding to hundreds (negative places)", + case9_mojo, + -2, + case9_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Rounding a very small number + var case10_mojo = Decimal("0.0000001234") + var case10_py = pydecimal.Decimal("0.0000001234") + run_benchmark( + "Rounding very small number", + case10_mojo, + 10, + case10_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 11: Rounding already rounded number (no change) + var case11_mojo = Decimal("123.45") + var case11_py = pydecimal.Decimal("123.45") + run_benchmark( + "Rounding already rounded number", + case11_mojo, + 2, + case11_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 12: Rounding to high precision (20 places) + var case12_mojo = Decimal("0.12345678901234567890123") + var case12_py = pydecimal.Decimal("0.12345678901234567890123") + run_benchmark( + "Rounding to high precision (20 places)", + case12_mojo, + 20, + case12_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Rounding to more places than input has + var case13_mojo = Decimal("123.456") + var case13_py = pydecimal.Decimal("123.456") + run_benchmark( + "Rounding to more places than input (10)", + case13_mojo, + 10, + case13_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Rounding number with trailing 9's + var case14_mojo = Decimal("9.999") + var case14_py = pydecimal.Decimal("9.999") + run_benchmark( + "Rounding number with trailing 9's", + case14_mojo, + 2, + case14_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 15: Rounding requiring carry propagation (9.99 -> 10.0) + var case15_mojo = Decimal("9.99") + var case15_py = pydecimal.Decimal("9.99") + run_benchmark( + "Rounding requiring carry propagation", + case15_mojo, + 1, + case15_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 16: Rounding exactly half with even preceding digit + var case16_mojo = Decimal("2.5") + var case16_py = pydecimal.Decimal("2.5") + run_benchmark( + "Rounding exactly half with even preceding digit", + case16_mojo, + 0, + case16_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 17: Rounding exactly half with odd preceding digit + var case17_mojo = Decimal("3.5") + var case17_py = pydecimal.Decimal("3.5") + run_benchmark( + "Rounding exactly half with odd preceding digit", + case17_mojo, + 0, + case17_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 18: Rounding value close to MAX + var case18_mojo = Decimal("12345678901234567890") - Decimal("0.12345") + var case18_py = pydecimal.Decimal(String(case18_mojo)) + run_benchmark( + "Rounding value close to MAX", + case18_mojo, + 2, + case18_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 19: Rounding minimum positive value + var case19_mojo = Decimal( + "0." + "0" * 27 + "1" + ) # Smallest positive decimal + var case19_py = pydecimal.Decimal(String(case19_mojo)) + run_benchmark( + "Rounding minimum positive value", + case19_mojo, + 28, + case19_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 20: Rounding perfectly formatted engineering value + var case20_mojo = Decimal("123456.789e-3") # 123.456789 + var case20_py = pydecimal.Decimal("123456.789e-3") + run_benchmark( + "Rounding engineering notation value", + case20_mojo, + 4, + case20_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 21: Rounding a number with long integer part + var case21_mojo = Decimal("12345678901234567.8901") + var case21_py = pydecimal.Decimal("12345678901234567.8901") + run_benchmark( + "Rounding number with long integer part", + case21_mojo, + 2, + case21_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 22: Rounding a number that's just below half + var case22_mojo = Decimal("10.12499999999999999") + var case22_py = pydecimal.Decimal("10.12499999999999999") + run_benchmark( + "Rounding number just below half", + case22_mojo, + 2, + case22_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 23: Rounding a number that's just above half + var case23_mojo = Decimal("10.12500000000000001") + var case23_py = pydecimal.Decimal("10.12500000000000001") + run_benchmark( + "Rounding number just above half", + case23_mojo, + 2, + case23_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 24: Rounding to max precision (28 places) + var case24_mojo = Decimal("0." + "1" * 29) # More digits than max precision + var case24_py = pydecimal.Decimal(String(case24_mojo)) + run_benchmark( + "Rounding to maximum precision (28)", + case24_mojo, + 28, + case24_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 25: Rounding zero + var case25_mojo = Decimal("0") + var case25_py = pydecimal.Decimal("0") + run_benchmark( + "Rounding zero", + case25_mojo, + 10, + case25_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 26: Rounding a series of mixed digits + var case26_mojo = Decimal("3.141592653589793238462643383") + var case26_py = pydecimal.Decimal("3.141592653589793238462643383") + run_benchmark( + "Rounding Pi to various places", + case26_mojo, + 15, + case26_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 27: Rounding negative number requiring carry propagation + var case27_mojo = Decimal("-9.99") + var case27_py = pydecimal.Decimal("-9.99") + run_benchmark( + "Rounding negative with carry propagation", + case27_mojo, + 1, + case27_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 28: Rounding with exact .5 and zeros after + var case28_mojo = Decimal("1.5000000000000000000") + var case28_py = pydecimal.Decimal("1.5000000000000000000") + run_benchmark( + "Rounding with exact .5 and zeros after", + case28_mojo, + 0, + case28_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 29: Rounding with exact .5 and non-zeros after + var case29_mojo = Decimal("1.50000000000000000001") + var case29_py = pydecimal.Decimal("1.50000000000000000001") + run_benchmark( + "Rounding with exact .5 and non-zeros after", + case29_mojo, + 0, + case29_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 30: Rounding extremely close to MAX value + var case30_mojo = Decimal("123456789012345678") - Decimal("0.000000001") + var case30_py = pydecimal.Decimal(String(case30_mojo)) + run_benchmark( + "Rounding extremely close to MAX", + case30_mojo, + 8, + case30_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 31: Random decimal with various digits after decimal + var case31_mojo = Decimal("7.389465718934026719043") + var case31_py = pydecimal.Decimal("7.389465718934026719043") + run_benchmark( + "Random decimal with various digits", + case31_mojo, + 12, + case31_py, + iterations, + log_file, + speedup_factors, + ) + + # Case 32: Number with alternating digits + var case32_mojo = Decimal("1.010101010101010101010101") + var case32_py = pydecimal.Decimal("1.010101010101010101010101") + run_benchmark( + "Number with alternating digits", + case32_mojo, + 18, + case32_py, + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== Rounding Benchmark Summary ===", log_file) + log_print("Benchmarked: 32 different rounding cases", log_file) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print("Average speedup: " + String(average_speedup) + "×", log_file) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)), + log_file, + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/benches/bench_sqrt.mojo b/benches/bench_sqrt.mojo index 0387938..2209ade 100644 --- a/benches/bench_sqrt.mojo +++ b/benches/bench_sqrt.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal square root operations. Compares performance against Python's decimal module with 20 diverse test cases. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from python import Python, PythonObject from time import perf_counter_ns import time diff --git a/benches/bench_subtract.mojo b/benches/bench_subtract.mojo index 6046c78..4249bba 100644 --- a/benches/bench_subtract.mojo +++ b/benches/bench_subtract.mojo @@ -3,7 +3,7 @@ Comprehensive benchmarks for Decimal subtraction operations. Compares performance against Python's decimal module. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from python import Python, PythonObject from time import perf_counter_ns import time diff --git a/mojoproject.toml b/mojoproject.toml index b3b16c1..6f8e246 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -31,7 +31,8 @@ t = "clear && magic run test" test_arith = "magic run package && magic run mojo test tests/test_arithmetics.mojo && magic run delete_package" test_div = "magic run package && magic run mojo test tests/test_division.mojo && magic run delete_package" test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && magic run delete_package" -test_round = "magic run package && magic run mojo test tests/test_rounding.mojo && magic run delete_package" +test_round = "magic run package && magic run mojo test tests/test_round.mojo && magic run delete_package" +test_creation = "magic run package && magic run mojo test tests/test_creation.mojo && magic run delete_package" # benches bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" @@ -39,6 +40,7 @@ b = "clear && magic run bench" bench_mul = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .. && magic run delete_package" bench_div = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .. && magic run delete_package" bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo && cd .. && magic run delete_package" +bench_round = "magic run package && cd benches && magic run mojo bench_round.mojo && cd .. && magic run delete_package" # before commit final = "magic run test && magic run bench" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 3d31ed4..63d943d 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -10,18 +10,11 @@ DeciMojo: A fixed-point decimal arithmetic library in Mojo. You can import a list of useful objects in one line, e.g., ```mojo -from decimojo import decimojo, dm, Decimal, D +from decimojo.prelude import dm, Decimal, RoundingMode ``` - -where `decimojo` is the module itself, `dm` is an alias for the module, -`Decimal` is the `Decimal` type, and `D` is an alias for the `Decimal` type. """ -import decimojo -import decimojo as dm - from .decimal import Decimal -from .decimal import Decimal as D from .rounding_mode import RoundingMode @@ -35,4 +28,5 @@ from .maths import ( round, absolute, ) + from .logic import greater, greater_equal, less, less_equal, equal, not_equal diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index 85588d8..93e4de2 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -11,11 +11,13 @@ # # Organization of methods: # - Constructors and life time methods +# - Constructing methods that are not dunders # - Output dunders, type-transfer dunders, and other type-transfer methods # - Basic unary arithmetic operation dunders # - Basic binary arithmetic operation dunders # - Basic binary logic operation dunders -# - Other dunders that implements tratis +# - Other dunders that implements traits +# - Mathematical methods that do not implement a trait (not a dunder) # - Other methods # - Internal methods # @@ -34,7 +36,11 @@ Implements basic object methods for working with decimal numbers. """ -from .rounding_mode import RoundingMode +import decimojo.logic +import decimojo.maths +from decimojo.rounding_mode import RoundingMode +import decimojo.str +import decimojo.utility @register_passable @@ -91,9 +97,9 @@ struct Decimal( alias MAX_AS_UINT256 = UInt256(79228162514264337593543950335) alias MAX_AS_INT256 = Int256(79228162514264337593543950335) alias MAX_AS_STRING = String("79228162514264337593543950335") - """Maximum value as a string of a 128-bit Decimal.""" - alias MAX_VALUE_DIGITS = 29 - """Length of the max value as a string. For 128-bit Decimal, it is 29 digits""" + """Maximum value as a string.""" + alias MAX_NUM_DIGITS = 29 + """Number of digits of the max value 79228162514264337593543950335.""" alias SIGN_MASK = UInt32(0x80000000) """Sign mask. `0b1000_0000_0000_0000_0000_0000_0000_0000`. 1 bit for sign (0 is positive and 1 is negative).""" @@ -114,7 +120,7 @@ struct Decimal( Returns a Decimal representing positive infinity. Internal representation: `0b0000_0000_0000_0000_0000_0000_0001`. """ - return Decimal(0, 0, 0, 0x00000001) + return Decimal.from_raw_words(0, 0, 0, 0x00000001) @staticmethod fn NEGATIVE_INFINITY() -> Decimal: @@ -122,7 +128,7 @@ struct Decimal( Returns a Decimal representing negative infinity. Internal representation: `0b1000_0000_0000_0000_0000_0000_0001`. """ - return Decimal(0, 0, 0, 0x80000001) + return Decimal.from_raw_words(0, 0, 0, 0x80000001) @staticmethod fn NAN() -> Decimal: @@ -130,7 +136,7 @@ struct Decimal( Returns a Decimal representing Not a Number (NaN). Internal representation: `0b0000_0000_0000_0000_0000_0000_0010`. """ - return Decimal(0, 0, 0, 0x00000010) + return Decimal.from_raw_words(0, 0, 0, 0x00000010) @staticmethod fn NEGATIVE_NAN() -> Decimal: @@ -138,28 +144,28 @@ struct Decimal( Returns a Decimal representing negative Not a Number. Internal representation: `0b1000_0000_0000_0000_0000_0000_0010`. """ - return Decimal(0, 0, 0, 0x80000010) + return Decimal.from_raw_words(0, 0, 0, 0x80000010) @staticmethod fn ZERO() -> Decimal: """ Returns a Decimal representing 0. """ - return Decimal(0, 0, 0, 0) + return Decimal.from_raw_words(0, 0, 0, 0) @staticmethod fn ONE() -> Decimal: """ Returns a Decimal representing 1. """ - return Decimal(1, 0, 0, 0) + return Decimal.from_raw_words(1, 0, 0, 0) @staticmethod fn NEGATIVE_ONE() -> Decimal: """ Returns a Decimal representing -1. """ - return Decimal(1, 0, 0, Decimal.SIGN_MASK) + return Decimal.from_raw_words(1, 0, 0, Decimal.SIGN_MASK) @staticmethod fn MAX() -> Decimal: @@ -167,14 +173,16 @@ struct Decimal( Returns the maximum possible Decimal value. This is equivalent to 79228162514264337593543950335. """ - return Decimal(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0) + return Decimal.from_raw_words(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0) @staticmethod fn MIN() -> Decimal: """Returns the minimum possible Decimal value (negative of MAX). This is equivalent to -79228162514264337593543950335. """ - return Decimal(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, Decimal.SIGN_MASK) + return Decimal.from_raw_words( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, Decimal.SIGN_MASK + ) # ===------------------------------------------------------------------=== # # Constructors and life time dunder methods @@ -195,8 +203,8 @@ struct Decimal( mid: UInt32, high: UInt32, scale: UInt32, - negative: Bool, - ): + sign: Bool, + ) raises: """ Initializes a Decimal with five components. If the scale is greater than MAX_SCALE, it is set to MAX_SCALE. @@ -206,8 +214,17 @@ struct Decimal( mid: Middle 32 bits of coefficient. high: Most significant 32 bits of coefficient. scale: Number of decimal places (0-28). - negative: True if the number is negative. + sign: True if the number is negative. """ + + if scale > Self.MAX_SCALE: + raise Error( + String( + "Error in Decimal constructor with five components: Scale" + " must be between 0 and 28, but got {}" + ).format(scale) + ) + self.low = low self.mid = mid self.high = high @@ -219,34 +236,11 @@ struct Decimal( flags |= (scale << Self.SCALE_SHIFT) & Self.SCALE_MASK # Set the sign bit if negative - if negative: + if sign: flags |= Self.SIGN_MASK self.flags = flags - # Now check if we need to round due to exceeding MAX_SCALE - if scale > Self.MAX_SCALE: - # We need to properly round the value, not just change the scale - var scale_diff = scale - Self.MAX_SCALE - # The 'self' is already initialized above, so we can call _scale_down on it - self = self._scale_down(Int(scale_diff), RoundingMode.HALF_EVEN()) - - # No else needed as the value is already properly set if scale <= MAX_SCALE - - fn __init__( - out self, low: UInt32, mid: UInt32, high: UInt32, flags: UInt32 - ): - """ - Initializes a Decimal with internal representation fields. - Uses the full constructor to properly handle scaling and rounding. - """ - # Extract sign and scale from flags - var is_negative = (flags & Self.SIGN_MASK) != 0 - var scale = (flags & Self.SCALE_MASK) >> Self.SCALE_SHIFT - - # Use the previous constructor which handles scale rounding properly - self = Self(low, mid, high, scale, is_negative) - fn __init__(out self, integer: Int): """ Initializes a Decimal from an integer. @@ -326,8 +320,8 @@ struct Decimal( # TODO: Add arguments to specify the scale and sign for all integer constructors fn __init__( - out self, integer: UInt128, scale: UInt32 = 0, negative: Bool = False - ): + out self, integer: UInt128, scale: UInt32 = 0, sign: Bool = False + ) raises: """ Initializes a Decimal from an UInt128 value. ***WARNING***: This constructor can only handle values up to 96 bits. @@ -336,11 +330,17 @@ struct Decimal( var mid = UInt32((integer >> 32) & 0xFFFFFFFF) var high = UInt32((integer >> 64) & 0xFFFFFFFF) - self = Decimal(low, mid, high, scale, negative) + try: + self = Decimal(low, mid, high, scale, sign) + except e: + raise Error( + "Error in Decimal constructor with UInt128, scale, and sign: ", + e, + ) fn __init__( - out self, integer: UInt256, scale: UInt32 = 0, negative: Bool = False - ): + out self, integer: UInt256, scale: UInt32 = 0, sign: Bool = False + ) raises: """ Initializes a Decimal from an UInt256 value. ***WARNING***: This constructor can only handle values up to 96 bits. @@ -349,7 +349,14 @@ struct Decimal( var mid = UInt32((integer >> 32) & 0xFFFFFFFF) var high = UInt32((integer >> 64) & 0xFFFFFFFF) - self = Decimal(low, mid, high, scale, negative) + try: + self = Decimal(low, mid, high, scale, sign) + except e: + raise Error( + "Error in Decimal constructor with UInt256, scale, and sign: ", + e, + ) + self = Decimal(low, mid, high, scale, sign) fn __init__(out self, s: String) raises: """ @@ -534,8 +541,8 @@ struct Decimal( else: string_of_integral_part = String("0") - if (len(string_of_integral_part) > Decimal.MAX_VALUE_DIGITS) or ( - len(string_of_integral_part) == Decimal.MAX_VALUE_DIGITS + if (len(string_of_integral_part) > Decimal.MAX_NUM_DIGITS) or ( + len(string_of_integral_part) == Decimal.MAX_NUM_DIGITS and (string_of_integral_part > Self.MAX_AS_STRING) ): raise Error( @@ -547,8 +554,8 @@ struct Decimal( # Check if the coefficient is too large # Recursively re-calculate the coefficient string after truncating and rounding # until it fits within the Decimal limits - while (len(string_of_coefficient) > Decimal.MAX_VALUE_DIGITS) or ( - len(string_of_coefficient) == Decimal.MAX_VALUE_DIGITS + while (len(string_of_coefficient) > Decimal.MAX_NUM_DIGITS) or ( + len(string_of_coefficient) == Decimal.MAX_NUM_DIGITS and (string_of_coefficient > Self.MAX_AS_STRING) ): var raw_length_of_coefficient = len(string_of_coefficient) @@ -556,10 +563,10 @@ struct Decimal( # If string_of_coefficient has more than 29 digits, truncate it to 29. # If string_of_coefficient has 29 digits and larger than MAX_AS_STRING, truncate it to 28. var rounding_digit = string_of_coefficient[ - min(Decimal.MAX_VALUE_DIGITS, len(string_of_coefficient) - 1) + min(Decimal.MAX_NUM_DIGITS, len(string_of_coefficient) - 1) ] string_of_coefficient = string_of_coefficient[ - : min(Decimal.MAX_VALUE_DIGITS, len(string_of_coefficient) - 1) + : min(Decimal.MAX_NUM_DIGITS, len(string_of_coefficient) - 1) ] scale = scale - ( @@ -592,8 +599,8 @@ struct Decimal( result_chars.insert(0, String("1")) # If adding a digit would exceed max length, drop the last digit and reduce scale - if len(result_chars) > Decimal.MAX_VALUE_DIGITS: - result_chars = result_chars[: Decimal.MAX_VALUE_DIGITS] + if len(result_chars) > Decimal.MAX_NUM_DIGITS: + result_chars = result_chars[: Decimal.MAX_NUM_DIGITS] if scale > 0: scale -= 1 @@ -679,6 +686,27 @@ struct Decimal( self.high = other.high self.flags = other.flags + # ===------------------------------------------------------------------=== # + # Constructing methods that are not dunders + # ===------------------------------------------------------------------=== # + + @staticmethod + fn from_raw_words( + low: UInt32, mid: UInt32, high: UInt32, flags: UInt32 + ) -> Self: + """ + Initializes a Decimal with internal representation fields. + We do not check whether the scale is within the valid range. + """ + + var result = Decimal() + result.low = low + result.mid = mid + result.high = high + result.flags = flags + + return result + # ===------------------------------------------------------------------=== # # Output dunders, type-transfer dunders, and other type-transfer methods # ===------------------------------------------------------------------=== # @@ -814,7 +842,9 @@ struct Decimal( Returns: The absolute value of this Decimal. """ - var result = Decimal(self.low, self.mid, self.high, self.flags) + var result = Decimal.from_raw_words( + self.low, self.mid, self.high, self.flags + ) result.flags &= ~Self.SIGN_MASK # Clear sign bit return result @@ -825,7 +855,9 @@ struct Decimal( if self.is_zero(): return Decimal.ZERO() - var result = Decimal(self.low, self.mid, self.high, self.flags) + var result = Decimal.from_raw_words( + self.low, self.mid, self.high, self.flags + ) result.flags ^= Self.SIGN_MASK # Flip sign bit return result @@ -848,21 +880,21 @@ struct Decimal( """ try: - return decimojo.add(self, other) + return decimojo.maths.add(self, other) except e: raise Error("Error in `__add__()`; ", e) fn __add__(self, other: Float64) raises -> Self: - return decimojo.add(self, Decimal(other)) + return decimojo.maths.add(self, Decimal(other)) fn __add__(self, other: Int) raises -> Self: - return decimojo.add(self, Decimal(other)) + return decimojo.maths.add(self, Decimal(other)) fn __radd__(self, other: Float64) raises -> Self: - return decimojo.add(Decimal(other), self) + return decimojo.maths.add(Decimal(other), self) fn __radd__(self, other: Int) raises -> Self: - return decimojo.add(Decimal(other), self) + return decimojo.maths.add(Decimal(other), self) fn __sub__(self, other: Decimal) raises -> Self: """ @@ -887,52 +919,52 @@ struct Decimal( """ try: - return decimojo.subtract(self, other) + return decimojo.maths.subtract(self, other) except e: raise Error("Error in `__sub__()`; ", e) fn __sub__(self, other: Float64) raises -> Self: - return decimojo.subtract(self, Decimal(other)) + return decimojo.maths.subtract(self, Decimal(other)) fn __sub__(self, other: Int) raises -> Self: - return decimojo.subtract(self, Decimal(other)) + return decimojo.maths.subtract(self, Decimal(other)) fn __rsub__(self, other: Float64) raises -> Self: - return decimojo.subtract(Decimal(other), self) + return decimojo.maths.subtract(Decimal(other), self) fn __rsub__(self, other: Int) raises -> Self: - return decimojo.subtract(Decimal(other), self) + return decimojo.maths.subtract(Decimal(other), self) fn __mul__(self, other: Decimal) raises -> Self: """ Multiplies two Decimal values and returns a new Decimal containing the product. """ - return decimojo.multiply(self, other) + return decimojo.maths.multiply(self, other) fn __mul__(self, other: Float64) raises -> Self: - return decimojo.multiply(self, Decimal(other)) + return decimojo.maths.multiply(self, Decimal(other)) fn __mul__(self, other: Int) raises -> Self: - return decimojo.multiply(self, Decimal(other)) + return decimojo.maths.multiply(self, Decimal(other)) fn __truediv__(self, other: Decimal) raises -> Self: """ Divides this Decimal by another Decimal and returns a new Decimal containing the result. """ - return decimojo.true_divide(self, other) + return decimojo.maths.true_divide(self, other) fn __truediv__(self, other: Float64) raises -> Self: - return decimojo.true_divide(self, Decimal(other)) + return decimojo.maths.true_divide(self, Decimal(other)) fn __truediv__(self, other: Int) raises -> Self: - return decimojo.true_divide(self, Decimal(other)) + return decimojo.maths.true_divide(self, Decimal(other)) fn __rtruediv__(self, other: Float64) raises -> Self: - return decimojo.true_divide(Decimal(other), self) + return decimojo.maths.true_divide(Decimal(other), self) fn __rtruediv__(self, other: Int) raises -> Self: - return decimojo.true_divide(Decimal(other), self) + return decimojo.maths.true_divide(Decimal(other), self) fn __pow__(self, exponent: Decimal) raises -> Self: """ @@ -974,7 +1006,7 @@ struct Decimal( Returns: True if self is greater than other, False otherwise. """ - return decimojo.greater(self, other) + return decimojo.logic.greater(self, other) fn __ge__(self, other: Decimal) -> Bool: """ @@ -986,7 +1018,7 @@ struct Decimal( Returns: True if self is greater than or equal to other, False otherwise. """ - return decimojo.greater_equal(self, other) + return decimojo.logic.greater_equal(self, other) fn __lt__(self, other: Decimal) -> Bool: """ @@ -998,7 +1030,7 @@ struct Decimal( Returns: True if self is less than other, False otherwise. """ - return decimojo.less(self, other) + return decimojo.logic.less(self, other) fn __le__(self, other: Decimal) -> Bool: """ @@ -1010,7 +1042,7 @@ struct Decimal( Returns: True if self is less than or equal to other, False otherwise. """ - return decimojo.less_equal(self, other) + return decimojo.logic.less_equal(self, other) fn __eq__(self, other: Decimal) -> Bool: """ @@ -1022,7 +1054,7 @@ struct Decimal( Returns: True if self is equal to other, False otherwise. """ - return decimojo.equal(self, other) + return decimojo.logic.equal(self, other) fn __ne__(self, other: Decimal) -> Bool: """ @@ -1034,63 +1066,70 @@ struct Decimal( Returns: True if self is not equal to other, False otherwise. """ - return decimojo.not_equal(self, other) + return decimojo.logic.not_equal(self, other) # ===------------------------------------------------------------------=== # - # Other dunders that implements tratis + # Other dunders that implements traits # round # ===------------------------------------------------------------------=== # - fn __round__( - self, ndigits: Int = 0, mode: RoundingMode = RoundingMode.HALF_EVEN() - ) raises -> Self: + fn __round__(self, ndigits: Int) -> Self: """ Rounds this Decimal to the specified number of decimal places. + If `ndigits` is not given, rounds to 0 decimal places. + If rounding causes overflow, returns the value itself. - Args: - ndigits: Number of decimal places to round to. - If 0 (default), rounds to the nearest integer. - If positive, rounds to the given number of decimal places. - If negative, rounds to the left of the decimal point. - mode: The rounding mode to use. Defaults to RoundingMode.HALF_EVEN. + raises: + Error: Calling `round()` failed. + """ - Returns: - A new Decimal rounded to the specified precision + try: + return decimojo.maths.round( + self, ndigits=ndigits, rounding_mode=RoundingMode.HALF_EVEN() + ) + except e: + return self - Raises: - Error: If the operation would result in overflow. + fn __round__(self) -> Self: + """**OVERLOAD**.""" - Examples: - ``` - round(Decimal("3.14159"), 2) # Returns 3.14 - round("3.14159") # Returns 3 - round("1234.5", -2) # Returns 1200 - ``` - . - """ + return self.__round__(ndigits=0) - return decimojo.round(self, ndigits, mode) + # ===------------------------------------------------------------------=== # + # Mathematical methods that do not implement a trait (not a dunder) + # round, sqrt + # ===------------------------------------------------------------------=== # - fn __round__(self, ndigits: Int = 0) -> Self: + fn round( + self, + ndigits: Int = 0, + rounding_mode: RoundingMode = RoundingMode.ROUND_HALF_EVEN, + ) raises -> Self: """ - **OVERLOAD** Rounds this Decimal to the specified number of decimal places. - """ + Compared to `__round__`, this method: + (1) Allows specifying the rounding mode. + (2) Raises an error if the operation would result in overflow. - return decimojo.round(self, ndigits, RoundingMode.HALF_EVEN()) + Args: + ndigits: The number of decimal places to round to. + Default is 0. + rounding_mode: The rounding mode to use. + Default is RoundingMode.ROUND_HALF_EVEN. - fn __round__(self) -> Self: - """ - **OVERLOAD** - Rounds this Decimal to the specified number of decimal places. - """ + Returns: + The rounded Decimal value. - return decimojo.round(self, 0, RoundingMode.HALF_EVEN()) + Raises: + Error: If calling `round()` failed. + """ - # ===------------------------------------------------------------------=== # - # Methematical methods that do not implement a trait (not a dunder) - # sqrt - # ===------------------------------------------------------------------=== # + try: + return decimojo.maths.round( + self, ndigits=ndigits, rounding_mode=rounding_mode + ) + except e: + raise Error("Error in `Decimal.round()`; ", e) fn sqrt(self) raises -> Self: """ @@ -1103,7 +1142,7 @@ struct Decimal( Error: If the operation would result in overflow. """ - return decimojo.sqrt(self) + return decimojo.maths.sqrt(self) # ===------------------------------------------------------------------=== # # Other methods @@ -1162,6 +1201,26 @@ struct Decimal( """Returns True if this Decimal is negative.""" return (self.flags & Self.SIGN_MASK) != 0 + fn is_one(self) -> Bool: + """ + Returns True if this Decimal represents the value 1. + If 10^scale == coefficient, then it's one. + `1` and `1.00` are considered ones. + """ + if self.is_negative(): + return False + + var scale = self.scale() + var coef = self.coefficient() + + if scale == 0 and coef == 1: + return True + + if UInt128(10) ** scale == coef: + return True + + return False + fn is_zero(self) -> Bool: """ Returns True if this Decimal represents zero. @@ -1224,8 +1283,8 @@ struct Decimal( - Zero if |self| = |other| - Negative value if |self| < |other| """ - var abs_self = decimojo.absolute(self) - var abs_other = decimojo.absolute(other) + var abs_self = decimojo.maths.absolute(self) + var abs_other = decimojo.maths.absolute(other) if abs_self > abs_other: return 1 @@ -1354,54 +1413,3 @@ struct Decimal( ) return result - - fn _scale_up(self, owned scale_diff: Int) -> Decimal: - """ - Internal method to scale up a decimal by: - - multiplying coefficient by 10^scale_diff - - increase the scale by scale_diff - - Args: - scale_diff: Number of decimal places to scale up by - - Returns: - A new Decimal with the scaled up value - """ - var result = self - - # Early return if no scaling needed - if scale_diff <= 0: - return result - - # Update the scale in the flags - var new_scale = self.scale() + scale_diff - if new_scale > Self.MAX_SCALE + 1: - # Cannot scale beyond max precision, limit the scaling - scale_diff = Self.MAX_SCALE + 1 - self.scale() - new_scale = Self.MAX_SCALE + 1 - - # With UInt128, we can represent the coefficient as a single value - var coefficient = UInt128(self.high) << 64 | UInt128( - self.mid - ) << 32 | UInt128(self.low) - - # Check if multiplication by 10^scale_diff would cause overflow - var max_coefficient = ~UInt128(0) / UInt128(10**scale_diff) - if coefficient > max_coefficient: - # Handle overflow case - limit to maximum value or raise error - coefficient = ~UInt128(0) - else: - # No overflow - safe to multiply - coefficient *= UInt128(10**scale_diff) - - # Extract the 32-bit components from the UInt128 - result.low = UInt32(coefficient & 0xFFFFFFFF) - result.mid = UInt32((coefficient >> 32) & 0xFFFFFFFF) - result.high = UInt32((coefficient >> 64) & 0xFFFFFFFF) - - # Set the new scale - result.flags = (self.flags & ~Self.SCALE_MASK) | ( - UInt32(new_scale << Self.SCALE_SHIFT) & Self.SCALE_MASK - ) - - return result diff --git a/src/decimojo/logic.mojo b/src/decimojo/logic.mojo index 6fb3220..7c4cf46 100644 --- a/src/decimojo/logic.mojo +++ b/src/decimojo/logic.mojo @@ -28,6 +28,7 @@ Implements functions for comparison operations on Decimal objects. """ from decimojo.decimal import Decimal +import decimojo.utility fn greater(a: Decimal, b: Decimal) -> Bool: @@ -182,6 +183,9 @@ fn _compare_abs(a: Decimal, b: Decimal) -> Int: - Positive value if |a| > |b| - Zero if |a| = |b| - Negative value if |a| < |b| + + raises: + Error: Calling `scale_up()` failed. """ # Normalize scales by scaling up the one with smaller scale var scale_a = a.scale() @@ -192,10 +196,17 @@ fn _compare_abs(a: Decimal, b: Decimal) -> Int: var b_copy = b # Scale up the decimal with smaller scale to match the other + # TODO: Treat this error properly if scale_a < scale_b: - a_copy = a._scale_up(scale_b - scale_a) + try: + a_copy = decimojo.utility.scale_up(a, scale_b - scale_a) + except: + a_copy = a elif scale_b < scale_a: - b_copy = b._scale_up(scale_a - scale_b) + try: + b_copy = decimojo.utility.scale_up(b, scale_a - scale_b) + except: + b_copy = b # Now both have the same scale, compare integer components # Compare high parts first (most significant) diff --git a/src/decimojo/maths/basic.mojo b/src/decimojo/maths/basic.mojo index dc37d39..dbee003 100644 --- a/src/decimojo/maths/basic.mojo +++ b/src/decimojo/maths/basic.mojo @@ -21,8 +21,11 @@ Implements functions for mathematical operations on Decimal objects. """ +import testing + from decimojo.decimal import Decimal from decimojo.rounding_mode import RoundingMode +import decimojo.utility # TODO: Like `multiply` use combined bits to determine the appropriate method @@ -116,7 +119,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # Determine the scale for the result var scale = min( max(x1.scale(), x2.scale()), - Decimal.MAX_VALUE_DIGITS + Decimal.MAX_NUM_DIGITS - decimojo.utility.number_of_digits(summation), ) ## If summation > 7922816251426433759354395033 @@ -145,7 +148,7 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: # Determine the scale for the result var scale = min( max(x1.scale(), x2.scale()), - Decimal.MAX_VALUE_DIGITS + Decimal.MAX_NUM_DIGITS - decimojo.utility.number_of_digits(diff), ) ## If summation > 7922816251426433759354395033 @@ -345,7 +348,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var num_digits_to_keep = num_digits_prod - ( combined_scale - Decimal.MAX_SCALE ) - var truncated_prod = decimojo.utility.truncate_to_digits( + var truncated_prod = decimojo.utility.round_to_keep_first_n_digits( prod, num_digits_to_keep ) var final_scale = min(Decimal.MAX_SCALE, combined_scale) @@ -376,7 +379,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var num_digits_to_keep = num_digits_prod - ( combined_scale - Decimal.MAX_SCALE ) - var truncated_prod = decimojo.utility.truncate_to_digits( + var truncated_prod = decimojo.utility.round_to_keep_first_n_digits( prod, num_digits_to_keep ) var final_scale = min(Decimal.MAX_SCALE, combined_scale) @@ -442,7 +445,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: else: var num_digits = decimojo.utility.number_of_digits(prod) var final_scale = min( - Decimal.MAX_VALUE_DIGITS - num_digits, combined_scale + Decimal.MAX_NUM_DIGITS - num_digits, combined_scale ) # Scale up before it overflows prod = prod * 10**final_scale @@ -484,11 +487,22 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: var num_digits_to_keep = num_digits - ( combined_scale - Decimal.MAX_SCALE ) - prod = decimojo.utility.truncate_to_digits(prod, num_digits_to_keep) + prod = decimojo.utility.round_to_keep_first_n_digits( + prod, num_digits_to_keep + ) var final_scale = min(Decimal.MAX_SCALE, combined_scale) + + if final_scale > Decimal.MAX_SCALE: + var ndigits_prod = decimojo.utility.number_of_digits(prod) + prod = decimojo.utility.round_to_keep_first_n_digits( + prod, ndigits_prod - (final_scale - Decimal.MAX_SCALE) + ) + final_scale = Decimal.MAX_SCALE + var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, final_scale, is_negative) # SUB-CASE: Both operands are moderate @@ -506,10 +520,10 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: prod ) - combined_scale # Truncated first 29 digits - var truncated_prod_at_max_length = decimojo.utility.truncate_to_digits( - prod, Decimal.MAX_VALUE_DIGITS + var truncated_prod_at_max_length = decimojo.utility.round_to_keep_first_n_digits( + prod, Decimal.MAX_NUM_DIGITS ) - if (num_digits_of_integral_part >= Decimal.MAX_VALUE_DIGITS) & ( + if (num_digits_of_integral_part >= Decimal.MAX_NUM_DIGITS) & ( truncated_prod_at_max_length > Decimal.MAX_AS_UINT128 ): raise Error("Error in `multiply()`: Decimal overflow") @@ -519,14 +533,14 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # If the first 29 digits does not exceed the limit, # the final coefficient can be of 29 digits. # The final scale can be 29 - num_digits_of_integral_part. - var num_digits_of_decimal_part = Decimal.MAX_VALUE_DIGITS - num_digits_of_integral_part + var num_digits_of_decimal_part = Decimal.MAX_NUM_DIGITS - num_digits_of_integral_part # If the first 29 digits exceed the limit, # we need to adjust the num_digits_of_decimal_part by -1 # so that the final coefficient will be of 28 digits. if truncated_prod_at_max_length > Decimal.MAX_AS_UINT128: num_digits_of_decimal_part -= 1 - prod = decimojo.utility.truncate_to_digits( - prod, Decimal.MAX_VALUE_DIGITS - 1 + prod = decimojo.utility.round_to_keep_first_n_digits( + prod, Decimal.MAX_NUM_DIGITS - 1 ) else: prod = truncated_prod_at_max_length @@ -534,10 +548,18 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # Yuhao's notes: I think combined_scale should always be smaller var final_scale = min(num_digits_of_decimal_part, combined_scale) + if final_scale > Decimal.MAX_SCALE: + var ndigits_prod = decimojo.utility.number_of_digits(prod) + prod = decimojo.utility.round_to_keep_first_n_digits( + prod, ndigits_prod - (final_scale - Decimal.MAX_SCALE) + ) + final_scale = Decimal.MAX_SCALE + # Extract the 32-bit components from the UInt128 product var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) var high = UInt32((prod >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, final_scale, is_negative) # REMAINING CASES: Both operands are big @@ -554,11 +576,11 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: prod ) - combined_scale # Truncated first 29 digits - var truncated_prod_at_max_length = decimojo.utility.truncate_to_digits( - prod, Decimal.MAX_VALUE_DIGITS + var truncated_prod_at_max_length = decimojo.utility.round_to_keep_first_n_digits( + prod, Decimal.MAX_NUM_DIGITS ) # Check for overflow of the integral part after rounding - if (num_digits_of_integral_part >= Decimal.MAX_VALUE_DIGITS) & ( + if (num_digits_of_integral_part >= Decimal.MAX_NUM_DIGITS) & ( truncated_prod_at_max_length > Decimal.MAX_AS_UINT256 ): raise Error("Error in `multiply()`: Decimal overflow") @@ -568,14 +590,14 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # If the first 29 digits does not exceed the limit, # the final coefficient can be of 29 digits. # The final scale can be 29 - num_digits_of_integral_part. - var num_digits_of_decimal_part = Decimal.MAX_VALUE_DIGITS - num_digits_of_integral_part + var num_digits_of_decimal_part = Decimal.MAX_NUM_DIGITS - num_digits_of_integral_part # If the first 29 digits exceed the limit, # we need to adjust the num_digits_of_decimal_part by -1 # so that the final coefficient will be of 28 digits. if truncated_prod_at_max_length > Decimal.MAX_AS_UINT256: num_digits_of_decimal_part -= 1 - prod = decimojo.utility.truncate_to_digits( - prod, Decimal.MAX_VALUE_DIGITS - 1 + prod = decimojo.utility.round_to_keep_first_n_digits( + prod, Decimal.MAX_NUM_DIGITS - 1 ) else: prod = truncated_prod_at_max_length @@ -583,6 +605,13 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: # I think combined_scale should always be smaller final_scale = min(num_digits_of_decimal_part, combined_scale) + if final_scale > Decimal.MAX_SCALE: + var ndigits_prod = decimojo.utility.number_of_digits(prod) + prod = decimojo.utility.round_to_keep_first_n_digits( + prod, ndigits_prod - (final_scale - Decimal.MAX_SCALE) + ) + final_scale = Decimal.MAX_SCALE + # Extract the 32-bit components from the UInt256 product var low = UInt32(prod & 0xFFFFFFFF) var mid = UInt32((prod >> 32) & 0xFFFFFFFF) @@ -669,7 +698,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # If the result can be stored in UInt128 if ( decimojo.utility.number_of_digits(x1_coef) - diff_scale - < Decimal.MAX_VALUE_DIGITS + < Decimal.MAX_NUM_DIGITS ): var quot = x1_coef * UInt128(10) ** (-diff_scale) # print("DEBUG: quot", quot) @@ -743,7 +772,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # If the result can be stored in UInt128 if ( decimojo.utility.number_of_digits(quot) - diff_scale - < Decimal.MAX_VALUE_DIGITS + < Decimal.MAX_NUM_DIGITS ): var quot = quot * UInt128(10) ** (-diff_scale) var low = UInt32(quot & 0xFFFFFFFF) @@ -802,45 +831,53 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # Yuhao's notes: remainder should be positive beacuse the previous cases have been handled # 朱宇浩注: 餘數應該爲正,因爲之前的特例已經處理過了 - var x1_number_of_digits = decimojo.utility.number_of_digits(x1_coef) - var x2_number_of_digits = decimojo.utility.number_of_digits(x2_coef) - var diff_digits = x1_number_of_digits - x2_number_of_digits + var x1_ndigits = decimojo.utility.number_of_digits(x1_coef) + var x2_ndigits = decimojo.utility.number_of_digits(x2_coef) + var diff_digits = x1_ndigits - x2_ndigits # Here is an estimation of the maximum possible number of digits of the quotient's integral part # If it is higher than 28, we need to use UInt256 to store the quotient - var est_max_num_of_digits_of_quot_int_part = diff_digits - diff_scale + 1 - var is_use_uint128 = est_max_num_of_digits_of_quot_int_part < Decimal.MAX_VALUE_DIGITS + var est_max_ndigits_quot_int_part = diff_digits - diff_scale + 1 + var is_use_uint128 = est_max_ndigits_quot_int_part < Decimal.MAX_NUM_DIGITS # SUB-CASE: Use UInt128 to store the quotient # If the quotient's integral part is less than 28 digits, we can use UInt128 # if is_use_uint128: var quot: UInt128 var rem: UInt128 - var ajusted_scale = 0 + var adjusted_scale = 0 # The adjusted dividend coefficient will not exceed 2^96 - 1 if diff_digits < 0: var adjusted_x1_coef = x1_coef * UInt128(10) ** (-diff_digits) quot = adjusted_x1_coef // x2_coef rem = adjusted_x1_coef % x2_coef - ajusted_scale = -diff_digits + adjusted_scale = -diff_digits else: quot = x1_coef // x2_coef rem = x1_coef % x2_coef if is_use_uint128: - # Maximum number of steps is MAX_VALUE_DIGITS - num_digits_first_quot + 1 - # num_digis_first_quot is the number of digits of the quotient before using long division + # Maximum number of steps is minimum of the following two values: + # - MAX_NUM_DIGITS - ndigits_initial_quot + 1 + # - Decimal.MAX_SCALE - diff_scale - adjusted_scale + 1 (significant digits be rounded off) + # ndigits_initial_quot is the number of digits of the quotient before using long division # The extra digit is used for rounding up when it is 5 and not exact division - # 最大步數加一,用於捨去項爲5且非精確相除時向上捨去 # digit is the tempory quotient digit var digit = UInt128(0) # The final step counter stands for the number of dicimal points var step_counter = 0 - var num_digits_first_quot = decimojo.utility.number_of_digits(quot) - while (rem != 0) and ( - step_counter - < (Decimal.MAX_VALUE_DIGITS - num_digits_first_quot + 1) + var ndigits_initial_quot = decimojo.utility.number_of_digits(quot) + while ( + (rem != 0) + and ( + step_counter + < (Decimal.MAX_NUM_DIGITS - ndigits_initial_quot + 1) + ) + and ( + step_counter + < Decimal.MAX_SCALE - diff_scale - adjusted_scale + 1 + ) ): # Multiply remainder by 10 rem *= 10 @@ -868,35 +905,56 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # Not exact division, round up the last digit quot += 1 - var scale_of_quot = step_counter + diff_scale + ajusted_scale + var scale_of_quot = step_counter + diff_scale + adjusted_scale + # If the scale is negative, we need to scale up the quotient if scale_of_quot < 0: quot = quot * UInt128(10) ** (-scale_of_quot) scale_of_quot = 0 - var number_of_digits_quot = decimojo.utility.number_of_digits(quot) - var number_of_digits_quot_int_part = number_of_digits_quot - scale_of_quot + var ndigits_quot = decimojo.utility.number_of_digits(quot) + var ndigits_quot_int_part = ndigits_quot - scale_of_quot # If quot is within MAX, return the result if quot <= Decimal.MAX_AS_UINT128: + if scale_of_quot > Decimal.MAX_SCALE: + quot = decimojo.utility.round_to_keep_first_n_digits( + quot, + ndigits_quot - (scale_of_quot - Decimal.MAX_SCALE), + ) + scale_of_quot = Decimal.MAX_SCALE + var low = UInt32(quot & 0xFFFFFFFF) var mid = UInt32((quot >> 32) & 0xFFFFFFFF) var high = UInt32((quot >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, scale_of_quot, is_negative) # Otherwise, we need to truncate the first 29 or 28 digits else: - var truncated_quot = decimojo.utility.truncate_to_digits( - quot, Decimal.MAX_VALUE_DIGITS + var truncated_quot = decimojo.utility.round_to_keep_first_n_digits( + quot, Decimal.MAX_NUM_DIGITS ) var scale_of_truncated_quot = ( - Decimal.MAX_VALUE_DIGITS - number_of_digits_quot_int_part + Decimal.MAX_NUM_DIGITS - ndigits_quot_int_part ) + if truncated_quot > Decimal.MAX_AS_UINT128: - truncated_quot = decimojo.utility.truncate_to_digits( - quot, Decimal.MAX_VALUE_DIGITS - 1 + truncated_quot = decimojo.utility.round_to_keep_first_n_digits( + quot, Decimal.MAX_NUM_DIGITS - 1 ) scale_of_truncated_quot -= 1 + if scale_of_truncated_quot > Decimal.MAX_SCALE: + var num_digits_truncated_quot = decimojo.utility.number_of_digits( + truncated_quot + ) + truncated_quot = decimojo.utility.round_to_keep_first_n_digits( + truncated_quot, + num_digits_truncated_quot + - (scale_of_truncated_quot - Decimal.MAX_SCALE), + ) + scale_of_truncated_quot = Decimal.MAX_SCALE + var low = UInt32(truncated_quot & 0xFFFFFFFF) var mid = UInt32((truncated_quot >> 32) & 0xFFFFFFFF) var high = UInt32((truncated_quot >> 64) & 0xFFFFFFFF) @@ -909,7 +967,7 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # It is almost the same also the case above, so we just use the same code else: - # Maximum number of steps is MAX_VALUE_DIGITS - num_digits_first_quot + 1 + # Maximum number of steps is MAX_NUM_DIGITS - ndigits_initial_quot + 1 # The extra digit is used for rounding up when it is 5 and not exact division # 最大步數加一,用於捨去項爲5且非精確相除時向上捨去 @@ -919,10 +977,17 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var digit = UInt256(0) # The final step counter stands for the number of dicimal points var step_counter = 0 - var num_digits_first_quot = decimojo.utility.number_of_digits(quot256) - while (rem256 != 0) and ( - step_counter - < (Decimal.MAX_VALUE_DIGITS - num_digits_first_quot + 1) + var ndigits_initial_quot = decimojo.utility.number_of_digits(quot256) + while ( + (rem256 != 0) + and ( + step_counter + < (Decimal.MAX_NUM_DIGITS - ndigits_initial_quot + 1) + ) + and ( + step_counter + < Decimal.MAX_SCALE - diff_scale - adjusted_scale + 1 + ) ): # Multiply remainder by 10 rem256 *= 10 @@ -943,46 +1008,63 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: # Not exact division, round up the last digit quot256 += 1 - var scale_of_quot = step_counter + diff_scale + ajusted_scale + var scale_of_quot = step_counter + diff_scale + adjusted_scale + # If the scale is negative, we need to scale up the quotient if scale_of_quot < 0: quot256 = quot256 * UInt256(10) ** (-scale_of_quot) scale_of_quot = 0 - var number_of_digits_quot = decimojo.utility.number_of_digits(quot256) - var number_of_digits_quot_int_part = number_of_digits_quot - scale_of_quot + var ndigits_quot = decimojo.utility.number_of_digits(quot256) + var ndigits_quot_int_part = ndigits_quot - scale_of_quot # If quot is within MAX, return the result if quot256 <= Decimal.MAX_AS_UINT256: + if scale_of_quot > Decimal.MAX_SCALE: + quot256 = decimojo.utility.round_to_keep_first_n_digits( + quot256, + ndigits_quot - (scale_of_quot - Decimal.MAX_SCALE), + ) + scale_of_quot = Decimal.MAX_SCALE + var low = UInt32(quot256 & 0xFFFFFFFF) var mid = UInt32((quot256 >> 32) & 0xFFFFFFFF) var high = UInt32((quot256 >> 64) & 0xFFFFFFFF) + return Decimal(low, mid, high, scale_of_quot, is_negative) # Otherwise, we need to truncate the first 29 or 28 digits else: - var truncated_quot = decimojo.utility.truncate_to_digits( - quot256, Decimal.MAX_VALUE_DIGITS + var truncated_quot = decimojo.utility.round_to_keep_first_n_digits( + quot256, Decimal.MAX_NUM_DIGITS ) # If integer part of quot is more than max, raise error - if (number_of_digits_quot_int_part > Decimal.MAX_VALUE_DIGITS) or ( - (number_of_digits_quot_int_part == Decimal.MAX_VALUE_DIGITS) + if (ndigits_quot_int_part > Decimal.MAX_NUM_DIGITS) or ( + (ndigits_quot_int_part == Decimal.MAX_NUM_DIGITS) and (truncated_quot > Decimal.MAX_AS_UINT256) ): raise Error("Error in `true_divide()`: Decimal overflow") var scale_of_truncated_quot = ( - Decimal.MAX_VALUE_DIGITS - number_of_digits_quot_int_part + Decimal.MAX_NUM_DIGITS - ndigits_quot_int_part ) if truncated_quot > Decimal.MAX_AS_UINT256: - truncated_quot = decimojo.utility.truncate_to_digits( - quot256, Decimal.MAX_VALUE_DIGITS - 1 + truncated_quot = decimojo.utility.round_to_keep_first_n_digits( + quot256, Decimal.MAX_NUM_DIGITS - 1 ) scale_of_truncated_quot -= 1 - # print("DEBUG: truncated_quot", truncated_quot) - # print("DEBUG: scale_of_truncated_quot", scale_of_truncated_quot) + if scale_of_truncated_quot > Decimal.MAX_SCALE: + var num_digits_truncated_quot = decimojo.utility.number_of_digits( + truncated_quot + ) + truncated_quot = decimojo.utility.round_to_keep_first_n_digits( + truncated_quot, + num_digits_truncated_quot + - (scale_of_truncated_quot - Decimal.MAX_SCALE), + ) + scale_of_truncated_quot = Decimal.MAX_SCALE var low = UInt32(truncated_quot & 0xFFFFFFFF) var mid = UInt32((truncated_quot >> 32) & 0xFFFFFFFF) diff --git a/src/decimojo/maths/exp.mojo b/src/decimojo/maths/exp.mojo index 7465916..8152d49 100644 --- a/src/decimojo/maths/exp.mojo +++ b/src/decimojo/maths/exp.mojo @@ -19,6 +19,8 @@ import math as builtin_math import testing +import decimojo.utility + fn power(base: Decimal, exponent: Decimal) raises -> Decimal: """ @@ -128,9 +130,6 @@ fn sqrt(x: Decimal) raises -> Decimal: if x.is_zero(): return Decimal.ZERO() - if x == Decimal.ONE(): - return Decimal.ONE() - var x_coef: UInt128 = x.coefficient() var x_scale = x.scale() @@ -148,19 +147,17 @@ fn sqrt(x: Decimal) raises -> Decimal: # For numbers with even scale elif x_scale % 2 == 0: var float_sqrt = builtin_math.sqrt(Float64(x_coef)) - guess = Decimal(UInt128(float_sqrt), negative=False, scale=x_scale >> 1) + guess = Decimal(UInt128(float_sqrt), scale=x_scale >> 1, sign=False) # print("DEBUG: scale is even") # For numbers with odd scale else: var float_sqrt = builtin_math.sqrt(Float64(x_coef)) * Float64(3.15625) guess = Decimal( - UInt128(float_sqrt), negative=False, scale=(x_scale + 1) >> 1 + UInt128(float_sqrt), scale=(x_scale + 1) >> 1, sign=False ) # print("DEBUG: scale is odd") - # print("DEBUG: initial guess", guess) - # print("DEBUG: initial guess", guess) testing.assert_false(guess.is_zero(), "Initial guess should not be zero") @@ -197,7 +194,7 @@ fn sqrt(x: Decimal) raises -> Decimal: # No need to do this if the last digit of the coefficient of guess is not zero if guess_coef % 10 == 0: var num_digits_x_ceof = decimojo.utility.number_of_digits(x_coef) - var num_digits_x_sqrt_coef = (num_digits_x_ceof + 1) >> 1 + var num_digits_x_sqrt_coef = (num_digits_x_ceof >> 1) + 1 var num_digits_guess_coef = decimojo.utility.number_of_digits( guess_coef ) @@ -215,7 +212,9 @@ fn sqrt(x: Decimal) raises -> Decimal: else: # print("DEBUG: guess", guess) # print("DEBUG: guess_coef after removing trailing zeros", guess_coef) - if guess_coef * guess_coef == x_coef: + if (guess_coef * guess_coef == x_coef) or ( + guess_coef * guess_coef == x_coef * 10 + ): var low = UInt32(guess_coef & 0xFFFFFFFF) var mid = UInt32((guess_coef >> 32) & 0xFFFFFFFF) var high = UInt32((guess_coef >> 64) & 0xFFFFFFFF) diff --git a/src/decimojo/maths/rounding.mojo b/src/decimojo/maths/rounding.mojo index dfdbb7b..b27c53e 100644 --- a/src/decimojo/maths/rounding.mojo +++ b/src/decimojo/maths/rounding.mojo @@ -19,8 +19,11 @@ Implements functions for mathematical operations on Decimal objects. """ +import testing + from decimojo.decimal import Decimal from decimojo.rounding_mode import RoundingMode +import decimojo.utility # ===------------------------------------------------------------------------===# # Rounding @@ -29,37 +32,119 @@ from decimojo.rounding_mode import RoundingMode fn round( number: Decimal, - decimal_places: Int = 0, - rounding_mode: RoundingMode = RoundingMode.HALF_EVEN(), -) -> Decimal: + ndigits: Int = 0, + rounding_mode: RoundingMode = RoundingMode.ROUND_HALF_EVEN, +) raises -> Decimal: """ Rounds the Decimal to the specified number of decimal places. Args: number: The Decimal to round. - decimal_places: Number of decimal places to round to. + ndigits: Number of decimal places to round to. Defaults to 0. rounding_mode: Rounding mode to use. - Defaults to HALF_EVEN/banker's rounding. + Defaults to ROUND_HALF_EVEN (banker's rounding). Returns: A new Decimal rounded to the specified number of decimal places. """ - var current_scale = number.scale() + + # Number of decimal places of the number is equal to the scale of the number + var x_scale = number.scale() + # `ndigits` is equal to the scale of the final number + var scale_diff = ndigits - x_scale # CASE: If already at the desired scale - # Return a copy + # Return a copy directly + # 情况一:如果已经在所需的标度上, 直接返回其副本 + # # round(Decimal("123.456"), 3) -> Decimal("123.456") - if current_scale == decimal_places: + if scale_diff == 0: return number - # TODO: CASE: If the number is an integer - # Return with more or less zeros until the desired scale - # round(Decimal("123"), 2) -> Decimal("123.00") + var x_coef = number.coefficient() + var ndigits_of_x = decimojo.utility.number_of_digits(x_coef) + + # CASE: If ndigits is larger than the current scale + # Scale up the coefficient of the number to the desired scale + # If scaling up causes an overflow, raise an error + # 情况二:如果ndigits大于当前标度, 将係數放大 + # + # Examples: + # round(Decimal("123.456"), 5) -> Decimal("123.45600") + # round(Decimal("123.456"), 29) -> Error + + if scale_diff > 0: + # If the digits of result > 29, directly raise an error + if ndigits_of_x + scale_diff > Decimal.MAX_NUM_DIGITS: + raise Error( + String( + "Error in `round()`: `ndigits = {}` causes the number of" + " digits in the significant figures of the result (={})" + " exceeds the maximum capacity (={})." + ).format( + ndigits, + ndigits_of_x + scale_diff, + Decimal.MAX_NUM_DIGITS, + ) + ) + + # If the digits of result <= 29, calculate the result by scaling up + else: + var res_coef = x_coef * UInt128(10) ** scale_diff + + # If the digits of result == 29, but the result >= 2^96, raise an error + if (ndigits_of_x + scale_diff == Decimal.MAX_NUM_DIGITS) and ( + res_coef > Decimal.MAX_AS_UINT128 + ): + raise Error( + String( + "Error in `round()`: `ndigits = {}` causes the" + " significant digits of the result (={}) exceeds the" + " maximum capacity (={})." + ).format(ndigits, res_coef, Decimal.MAX_AS_UINT128) + ) + + # In other cases, return the result + else: + return Decimal( + res_coef, scale=ndigits, sign=number.is_negative() + ) + + # CASE: If ndigits is smaller than the current scale + # Scale down the coefficient of the number to the desired scale and round + # 情况三:如果ndigits小于当前标度, 将係數縮小, 然后捨去 + # + # If `ndigits` is negative, the result need to be scaled up again. + # + # Examples: + # round(Decimal("987.654321"), 3) -> Decimal("987.654") + # round(Decimal("987.654321"), -2) -> Decimal("1000") + # round(Decimal("987.654321"), -3) -> Decimal("1000") + # round(Decimal("987.654321"), -4) -> Decimal("0") + + else: + # scale_diff < 0 + # Calculate the number of digits to keep + var ndigits_to_keep = ndigits_of_x + scale_diff + + # Keep the first `ndigits_to_keep` digits with specified rounding mode + var res_coef = decimojo.utility.round_to_keep_first_n_digits( + x_coef, ndigits=ndigits_to_keep, rounding_mode=rounding_mode + ) + + if ndigits >= 0: + return Decimal(res_coef, scale=ndigits, sign=number.is_negative()) + + # if `ndigits` is negative and `ndigits_to_keep` >= 0, scale up the result + elif ndigits_to_keep >= 0: + res_coef *= UInt128(10) ** (-ndigits) + return Decimal(res_coef, scale=0, sign=number.is_negative()) - # If we need more decimal places, scale up - if decimal_places > current_scale: - return number._scale_up(decimal_places - current_scale) + # if `ndigits` is negative and `ndigits_to_keep` < 0, return 0 + else: + return Decimal.ZERO() - # Otherwise, scale down with the specified rounding mode - return number._scale_down(current_scale - decimal_places, rounding_mode) + # Add a fallback raise even if it seems unreachable + testing.assert_true(False, "Unreachable code path reached") + return number diff --git a/src/decimojo/prelude.mojo b/src/decimojo/prelude.mojo new file mode 100644 index 0000000..b4f070a --- /dev/null +++ b/src/decimojo/prelude.mojo @@ -0,0 +1,20 @@ +""" +Provides a list of things that can be imported at one time. +The list contains the functions or types that are the most essential for a user. + +You can use the following code to import them: + +```mojo +from decimojo.prelude import dm, Decimal, RoundingMode +``` + +Or + +```mojo +from decimojo.prelude import * +``` +""" + +import decimojo as dm +from decimojo.decimal import Decimal +from decimojo.rounding_mode import RoundingMode diff --git a/src/decimojo/rounding_mode.mojo b/src/decimojo/rounding_mode.mojo index 9fe27bd..6bf1968 100644 --- a/src/decimojo/rounding_mode.mojo +++ b/src/decimojo/rounding_mode.mojo @@ -17,13 +17,14 @@ struct RoundingMode: """ # alias - alias down = Self.DOWN() - alias half_up = Self.HALF_UP() - alias half_even = Self.HALF_EVEN() - alias up = Self.UP() + alias ROUND_DOWN = Self.DOWN() + alias ROUND_HALF_UP = Self.HALF_UP() + alias ROUND_HALF_EVEN = Self.HALF_EVEN() + alias ROUND_UP = Self.UP() # Internal value var value: Int + """Internal value representing the rounding mode.""" # Static constants for each rounding mode @staticmethod @@ -51,3 +52,18 @@ struct RoundingMode: fn __eq__(self, other: Self) -> Bool: return self.value == other.value + + fn __eq__(self, other: String) -> Bool: + return String(self) == other + + fn __str__(self) -> String: + if self == Self.DOWN(): + return "ROUND_DOWN" + elif self == Self.HALF_UP(): + return "ROUND_HALF_UP" + elif self == Self.HALF_EVEN(): + return "ROUND_HALF_EVEN" + elif self == Self.UP(): + return "ROUND_UP" + else: + return "UNKNOWN_ROUNDING_MODE" diff --git a/src/decimojo/utility.mojo b/src/decimojo/utility.mojo index 82f4cfa..97adfb5 100644 --- a/src/decimojo/utility.mojo +++ b/src/decimojo/utility.mojo @@ -50,6 +50,90 @@ fn bitcast[dtype: DType](dec: Decimal) -> Scalar[dtype]: return result +fn scale_up(value: Decimal, owned level: Int) raises -> Decimal: + """ + Increase the scale of a Decimal while keeping the value unchanged. + Internally, this means multiplying the coefficient by 10^scale_diff + and increasing the scale by scale_diff simultaneously. + + Args: + value: The Decimal to scale up. + level: Number of decimal places to scale up by. + + Returns: + A new Decimal with the scaled up value. + + Raises: + Error: If the level is less than 0. + + Examples: + + ```mojo + from decimojo import Decimal + from decimojo.utility import scale_up + var d1 = Decimal("5") # 5 + var d2 = scale_up(d1, 2) # Result: 5.00 (same value, different representation) + print(d1) # 5 + print(d2) # 5.00 + print(d2.scale()) # 2 + + var d3 = Decimal("123.456") # 123.456 + var d4 = scale_up(d3, 3) # Result: 123.456000 + print(d3) # 123.456 + print(d4) # 123.456000 + print(d4.scale()) # 6 + ``` + . + """ + + if level < 0: + raise Error("Error in `scale_up()`: Level must be greater than 0") + + # Early return if no scaling needed + if level == 0: + return value + + var result = value + + # Update the scale in the flags + var new_scale = value.scale() + level + + # TODO: Check if multiplication by 10^level would cause overflow + # If yes, then raise an error + if new_scale > Decimal.MAX_SCALE + 1: + # Cannot scale beyond max precision, limit the scaling + level = Decimal.MAX_SCALE + 1 - value.scale() + new_scale = Decimal.MAX_SCALE + 1 + + # With UInt128, we can represent the coefficient as a single value + var coefficient = UInt128(value.high) << 64 | UInt128( + value.mid + ) << 32 | UInt128(value.low) + + # TODO: Check if multiplication by 10^level would cause overflow + # If yes, then raise an error + # + var max_coefficient = ~UInt128(0) / UInt128(10**level) + if coefficient > max_coefficient: + # Handle overflow case - limit to maximum value or raise error + coefficient = ~UInt128(0) + else: + # No overflow - safe to multiply + coefficient *= UInt128(10**level) + + # Extract the 32-bit components from the UInt128 + result.low = UInt32(coefficient & 0xFFFFFFFF) + result.mid = UInt32((coefficient >> 32) & 0xFFFFFFFF) + result.high = UInt32((coefficient >> 64) & 0xFFFFFFFF) + + # Set the new scale + result.flags = (value.flags & ~Decimal.SCALE_MASK) | ( + UInt32(new_scale << Decimal.SCALE_SHIFT) & Decimal.SCALE_MASK + ) + + return result + + fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: """ Truncates a UInt256 or UInt128 value to be as closer to the max value of @@ -86,9 +170,9 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: else: # Calculate how many digits we need to truncate - # Calculate how many digits to keep (MAX_VALUE_DIGITS = 29) - var num_digits = number_of_digits(value) - var digits_to_remove = num_digits - Decimal.MAX_VALUE_DIGITS + # Calculate how many digits to keep (MAX_NUM_DIGITS = 29) + var ndigits = number_of_digits(value) + var digits_to_remove = ndigits - Decimal.MAX_NUM_DIGITS # Collect digits for rounding decision var divisor = ValueType(10) ** ValueType(digits_to_remove) @@ -178,69 +262,78 @@ fn truncate_to_max[dtype: DType, //](value: Scalar[dtype]) -> Scalar[dtype]: return truncated_value -# TODO: Evalulate whether this can replace truncate_to_max in some cases. +# TODO: Evaluate whether this can replace truncate_to_max in some cases. # TODO: Add rounding modes to this function. -fn truncate_to_digits[ +fn round_to_keep_first_n_digits[ dtype: DType, // -](value: Scalar[dtype], num_digits: Int) -> Scalar[dtype]: +]( + value: Scalar[dtype], + ndigits: Int, + rounding_mode: RoundingMode = RoundingMode.ROUND_HALF_EVEN, +) -> Scalar[dtype]: """ - Truncates a UInt256 or UInt128 value to the specified number of digits. - Uses banker's rounding (ROUND_HALF_EVEN) for any truncated digits. + Rounds and keeps the first n digits of a integral value. + Default to use banker's rounding (ROUND_HALF_EVEN) for any truncated digits. `792281625142643375935439503356` with digits 2 will be truncated to `79`. `997` with digits 2 will be truncated to `100`. - This is useful in two cases: + Parameters: + dtype: Must be either uint128 or uint256. + + Args: + value: The integral value to truncate. + ndigits: The number of significant digits to evaluate. + rounding_mode: The rounding mode to use. + + Constraints: + `dtype` must be either `DType.uint128` or `DType.uint256`. + + Returns: + The truncated value. + + Notes: + + This function is useful in two cases: + (1) When you want to evaluate whether the coefficient will overflow after rounding, just look the first N digits (after rounding). If the truncated value is larger than the maximum, then it will overflow. Then you need to either raise an error (in case scale = 0 or integral part overflows), or keep only the first 28 digits in the coefficient. + (2) When you want to round a value. - The function is useful in the following cases. + There are some examples: - When you want to apply a scale of 31 to the coefficient `997`, it will be + - When you want to apply a scale of 31 to the coefficient `997`, it will be `0.0000000000000000000000000000997` with 31 digits. However, we can only store 28 digits in the coefficient (Decimal.MAX_SCALE = 28). Therefore, we need to truncate the coefficient to 0 (`3 - (31 - 28)`) digits and round it to the nearest even number. The truncated ceofficient will be `1`. Note that `truncated_digits = 1` which is not equal to - `num_digits = 0`, meaning there is a rounding to next digit. + `ndigits = 0`, meaning there is a rounding to next digit. The final decimal value will be `0.0000000000000000000000000001`. - When you want to apply a scale of 29 to the coefficient `234567`, it will be - `0.00000000000000000000000234567` with 29 digits. However, we can only + - When you want to apply a scale of 29 to the coefficient `234567`, it will + be `0.00000000000000000000000234567` with 29 digits. However, we can only store 28 digits in the coefficient (Decimal.MAX_SCALE = 28). Therefore, we need to truncate the coefficient to 5 (`6 - (29 - 28)`) digits and round it to the nearest even number. - The truncated ceofficient will be `23457`. + The truncated coefficient will be `23457`. The final decimal value will be `0.0000000000000000000000023457`. - When you want to apply a scale of 5 to the coefficient `234567`, it will be - `2.34567` with 5 digits. - Since `num_digits_to_keep = 6 - (5 - 28) = 29`, + - When you want to apply a scale of 5 to the coefficient `234567`, it will + be `2.34567` with 5 digits. + Since `ndigits_to_keep = 6 - (5 - 28) = 29`, it is greater and equal to the number of digits of the input value. The function will return the value as it is. - It can also be used for rounding function. For example, if you want to round - `12.34567` (`1234567` with scale `5`) to 2 digits, + - It can also be used for rounding function. For example, if you want to + round `12.34567` (`1234567` with scale `5`) to 2 digits, the function input will be `234567` and `4 = (7 - 5) + 2`. That is (number of digits - scale) + number of rounding points. The output is `1235`. - - Parameters: - dtype: Must be either uint128 or uint256. - - Args: - value: The UInt256 value to truncate. - num_digits: The number of significant digits to evalulate. - - Constraints: - `dtype` must be either `DType.uint128` or `DType.uint256`. - - Returns: - The truncated UInt256 value, guaranteed to fit within 96 bits. """ alias ValueType = Scalar[dtype] @@ -250,49 +343,66 @@ fn truncate_to_digits[ "must be uint128 or uint256", ]() - if num_digits < 0: + # CASE: The number of digits is less than 0 + # Return 0. + # + # Example: + # 123_456 keep -1 digits => 0 + if ndigits < 0: return 0 - var num_significant_digits = number_of_digits(value) - # If the number of digits is less than or equal to the specified digits, - # return the value - if num_significant_digits <= num_digits: + var ndigits_of_x = number_of_digits(value) + + # CASE: If the number of digits is greater than or equal to the specified digits + # Return the value. + # + # Example: + # 123_456 keep 7 digits => 123_456 + if ndigits >= ndigits_of_x: return value + # CASE: If the number of digits is less than the specified digits + # Return the value. + # + # Example: + # 123_456 keep 4 digits => 1_235 else: # Calculate how many digits we need to truncate - # Calculate how many digits to keep (MAX_VALUE_DIGITS = 29) - var num_digits_to_remove = num_significant_digits - num_digits + # Calculate how many digits to keep (MAX_NUM_DIGITS = 29) + var ndigits_to_remove = ndigits_of_x - ndigits # Collect digits for rounding decision - divisor = ValueType(10) ** ValueType(num_digits_to_remove) - truncated_value = value // divisor + var divisor = ValueType(10) ** ValueType(ndigits_to_remove) + var truncated_value = value // divisor var remainder = value % divisor - # Get the most significant digit of the remainder for rounding - var rounding_digit = remainder // 10 ** (num_digits_to_remove - 1) - - # Check if we need to round up based on banker's rounding (ROUND_HALF_EVEN) - var round_up = False - - # If rounding digit is > 5, round up - if rounding_digit > 5: - round_up = True - # If rounding digit is 5, check if there are any non-zero digits after it - elif rounding_digit == 5: - var has_nonzero_after = remainder > 5 * 10 ** ( - num_digits_to_remove - 1 - ) - # If there are non-zero digits after, round up - if has_nonzero_after: - round_up = True - # Otherwise, round to even (round up if last kept digit is odd) - else: - round_up = (truncated_value % 2) == 1 + # If RoundingMode is ROUND_DOWN, just truncate the value + if rounding_mode == RoundingMode.ROUND_DOWN: + pass - # Apply rounding if needed - if round_up: - truncated_value += 1 + # If RoundingMode is ROUND_UP, round up the value if remainder is greater than 0 + elif rounding_mode == RoundingMode.ROUND_UP: + if remainder > 0: + truncated_value += 1 + + # If RoundingMode is ROUND_HALF_UP, round up the value if remainder is greater than 5 + elif rounding_mode == RoundingMode.ROUND_HALF_UP: + var cutoff_value = 5 * 10 ** (ndigits_to_remove - 1) + if remainder >= cutoff_value: + truncated_value += 1 + + # If RoundingMode is ROUND_HALF_EVEN, round to nearest even digit if equidistant + else: + var cutoff_value = 5 * 10 ** (ndigits_to_remove - 1) + if remainder > cutoff_value: + truncated_value += 1 + elif remainder == cutoff_value: + # If truncated_value is even, do not round up + # If truncated_value is odd, round up + truncated_value += truncated_value % 2 + else: + # Do nothing + pass return truncated_value diff --git a/tests/test_arithmetics.mojo b/tests/test_arithmetics.mojo index 61ac158..ade4bd5 100644 --- a/tests/test_arithmetics.mojo +++ b/tests/test_arithmetics.mojo @@ -2,7 +2,7 @@ Test Decimal arithmetic operations including addition, subtraction, and negation. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode import testing diff --git a/tests/test_conversions.mojo b/tests/test_conversions.mojo index 2b95a21..edcd586 100644 --- a/tests/test_conversions.mojo +++ b/tests/test_conversions.mojo @@ -3,7 +3,7 @@ Test Decimal conversion methods: __int__, __float__, and __str__ for different numerical cases. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode import testing import time diff --git a/tests/test_creation.mojo b/tests/test_creation.mojo index 2bcd15c..4aa2341 100644 --- a/tests/test_creation.mojo +++ b/tests/test_creation.mojo @@ -1,7 +1,7 @@ """ Test Decimal creation from integer, float, or string values. """ -from decimojo import Decimal +from decimojo.prelude import dm, Decimal, RoundingMode import testing @@ -398,10 +398,10 @@ fn test_decimal_from_components() raises: testing.assert_equal(max_scale.scale(), 28, "Maximum scale should be 28") # Test case 12: Overflow scale protection - var overflow_scale = Decimal(123, 0, 0, 100, False) - testing.assert_true( - overflow_scale.scale() <= 28, "Scale should be capped to max precision" - ) + try: + var _overflow_scale = Decimal(123, 0, 0, 100, False) + except: + print("Successfully caught overflow scale error") print("All component constructor tests passed!") diff --git a/tests/test_division.mojo b/tests/test_division.mojo index a66e627..3ae347b 100644 --- a/tests/test_division.mojo +++ b/tests/test_division.mojo @@ -3,7 +3,7 @@ Comprehensive test suite for Decimal division operations. Includes 100 test cases covering edge cases, precision limits, and various scenarios. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode import testing diff --git a/tests/test_logic.mojo b/tests/test_logic.mojo index 5712fd3..44486ae 100644 --- a/tests/test_logic.mojo +++ b/tests/test_logic.mojo @@ -2,7 +2,7 @@ Test Decimal logic operations for comparison, including basic comparisons, edge cases, special handling for zero values, and operator overloads. """ -from decimojo import Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from decimojo.logic import ( greater, greater_equal, diff --git a/tests/test_rounding.mojo b/tests/test_round.mojo similarity index 97% rename from tests/test_rounding.mojo rename to tests/test_round.mojo index b7e2faf..90a80a3 100644 --- a/tests/test_rounding.mojo +++ b/tests/test_round.mojo @@ -1,7 +1,7 @@ """ -Test Decimal rounding methods with different rounding modes and precision levels. +Test Decimal round methods with different rounding modes and precision levels. """ -from decimojo import dm, Decimal, RoundingMode +from decimojo.prelude import dm, Decimal, RoundingMode import testing diff --git a/tests/test_sqrt.mojo b/tests/test_sqrt.mojo index f5404b7..d8e53a6 100644 --- a/tests/test_sqrt.mojo +++ b/tests/test_sqrt.mojo @@ -1,7 +1,7 @@ """ Comprehensive tests for the sqrt function of the Decimal type. """ -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from decimojo import sqrt import testing diff --git a/tests/test_utility.mojo b/tests/test_utility.mojo index b9eaf13..b98bde8 100644 --- a/tests/test_utility.mojo +++ b/tests/test_utility.mojo @@ -5,7 +5,7 @@ Tests for the utility functions in the decimojo.utility module. from testing import assert_equal, assert_true import max -from decimojo import dm, Decimal +from decimojo.prelude import dm, Decimal, RoundingMode from decimojo.utility import truncate_to_max, number_of_digits @@ -161,61 +161,75 @@ fn test_truncate_to_max_banker_rounding() raises: print("✓ All truncate_to_max banker's rounding tests passed!") -fn test_truncate_to_digits() raises: - """Test the truncate_to_digits function for proper digit truncation and rounding. +fn test_round_to_keep_first_n_digits() raises: + """Test the round_to_keep_first_n_digits function for proper digit truncation and rounding. """ - print("Testing truncate_to_digits...") + print("Testing round_to_keep_first_n_digits...") # Test case 1: Value with more digits than to keep (round to nearest power of 10) var case1 = UInt128(997) var case1_expected = UInt128(1) - assert_equal(dm.utility.truncate_to_digits(case1, 0), case1_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case1, 0), case1_expected + ) # Test case 2: Value with one more digit than to keep var case2 = UInt128(234567) var case2_expected = UInt128(23457) - assert_equal(dm.utility.truncate_to_digits(case2, 5), case2_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case2, 5), case2_expected + ) # Test case 3: Value with fewer digits than to keep (should return original) var case3 = UInt128(234567) - assert_equal(dm.utility.truncate_to_digits(case3, 29), case3) + assert_equal(dm.utility.round_to_keep_first_n_digits(case3, 29), case3) # Test case 4: Test banker's rounding with 5 (round to even) var case4a = UInt128(12345) # Last digit is 5, preceding digit is even var case4a_expected = UInt128(1234) - assert_equal(dm.utility.truncate_to_digits(case4a, 4), case4a_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case4a, 4), case4a_expected + ) var case4b = UInt128(23455) # Last digit is 5, preceding digit is odd var case4b_expected = UInt128(2346) - assert_equal(dm.utility.truncate_to_digits(case4b, 4), case4b_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case4b, 4), case4b_expected + ) # Test case 5: Rounding down (< 5) var case5 = UInt128(12342) var case5_expected = UInt128(1234) - assert_equal(dm.utility.truncate_to_digits(case5, 4), case5_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case5, 4), case5_expected + ) # Test case 6: Rounding up (> 5) var case6 = UInt128(12347) var case6_expected = UInt128(1235) - assert_equal(dm.utility.truncate_to_digits(case6, 4), case6_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case6, 4), case6_expected + ) # Test case 7: Zero input var case7 = UInt128(0) - assert_equal(dm.utility.truncate_to_digits(case7, 5), UInt128(0)) + assert_equal(dm.utility.round_to_keep_first_n_digits(case7, 5), UInt128(0)) # Test case 8: Single digit input var case8 = UInt128(7) - assert_equal(dm.utility.truncate_to_digits(case8, 1), UInt128(7)) + assert_equal(dm.utility.round_to_keep_first_n_digits(case8, 1), UInt128(7)) assert_equal( - dm.utility.truncate_to_digits(case8, 0), UInt128(1) + dm.utility.round_to_keep_first_n_digits(case8, 0), UInt128(1) ) # Round to nearest power of 10 # Test case 9: Large value with UInt256 var case9 = UInt256(9876543210987654321) var case9_expected = UInt256(987654321098765432) - assert_equal(dm.utility.truncate_to_digits(case9, 18), case9_expected) + assert_equal( + dm.utility.round_to_keep_first_n_digits(case9, 18), case9_expected + ) - print("✓ All truncate_to_digits tests passed!") + print("✓ All round_to_keep_first_n_digits tests passed!") fn test_bitcast() raises: @@ -253,7 +267,7 @@ fn test_bitcast() raises: assert_equal(large_scale_coef, large_scale_bits) # Test case 6: Custom bit pattern - var test_decimal = Decimal(12345, 67890, 0xABCDEF, 0x55) + var test_decimal = Decimal.from_raw_words(12345, 67890, 0xABCDEF, 0x55) var test_coef = test_decimal.coefficient() var test_bits = dm.utility.bitcast[DType.uint128](test_decimal) assert_equal(test_coef, test_bits) @@ -280,7 +294,7 @@ fn test_all() raises: test_bitcast() print() - test_truncate_to_digits() + test_round_to_keep_first_n_digits() print() print("✓✓✓ All utility module tests passed! ✓✓✓")