diff --git a/benches/bench.mojo b/benches/bench.mojo index 623cb7e..669ae38 100644 --- a/benches/bench.mojo +++ b/benches/bench.mojo @@ -2,6 +2,8 @@ from bench_add import main as bench_add from bench_subtract import main as bench_subtract from bench_multiply import main as bench_multiply from bench_divide import main as bench_divide +from bench_floor_divide import main as bench_floor_divide +from bench_modulo import main as bench_modulo from bench_sqrt import main as bench_sqrt from bench_from_float import main as bench_from_float from bench_from_string import main as bench_from_string @@ -21,6 +23,8 @@ fn main() raises: bench_subtract() bench_multiply() bench_divide() + bench_floor_divide() + bench_modulo() bench_sqrt() bench_from_float() bench_from_string() diff --git a/benches/bench_floor_divide.mojo b/benches/bench_floor_divide.mojo new file mode 100644 index 0000000..2920b8f --- /dev/null +++ b/benches/bench_floor_divide.mojo @@ -0,0 +1,406 @@ +""" +Comprehensive benchmarks for Decimal floor division (//) operation. +Compares performance against Python's decimal module with diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_floor_divide_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark_floor_divide( + name: String, + dividend: String, + divisor: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal floor division with Python Decimal floor division. + + Args: + name: Name of the benchmark case. + dividend: String representation of the dividend. + divisor: String representation of the divisor. + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Dividend: " + dividend, log_file) + log_print("Divisor: " + divisor, log_file) + + # Set up Mojo and Python values + var mojo_dividend = Decimal(dividend) + var mojo_divisor = Decimal(divisor) + var pydecimal = Python.import_module("decimal") + var py_dividend = pydecimal.Decimal(dividend) + var py_divisor = pydecimal.Decimal(divisor) + + # Execute the operations once to verify correctness + try: + var mojo_result = mojo_dividend // mojo_divisor + var py_result = py_dividend // py_divisor + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = mojo_dividend // mojo_divisor + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation using // operator + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_dividend // py_divisor + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo //: " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python //: " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + except e: + log_print("Error occurred during benchmark: " + String(e), log_file) + log_print("Skipping this benchmark case", log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Floor Division (//) Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 10000 # Higher iterations as this operation should be fast + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning floor division (//) benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Basic integer floor division with no remainder + run_benchmark_floor_divide( + "Integer division, no remainder", + "10", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Basic integer floor division with remainder + run_benchmark_floor_divide( + "Integer division, with remainder", + "10", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Division with decimal values + run_benchmark_floor_divide( + "Decimal division", + "10.5", + "2.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Division resulting in a decimal value + run_benchmark_floor_divide( + "Division resulting in integer", + "5", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Division with different decimal places + run_benchmark_floor_divide( + "Different decimal places", + "10.75", + "1.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 6: Negative dividend, positive divisor + run_benchmark_floor_divide( + "Negative dividend, positive divisor", + "-10", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 7: Positive dividend, negative divisor + run_benchmark_floor_divide( + "Positive dividend, negative divisor", + "10", + "-3", + iterations, + log_file, + speedup_factors, + ) + + # Case 8: Negative dividend, negative divisor + run_benchmark_floor_divide( + "Negative dividend, negative divisor", + "-10", + "-3", + iterations, + log_file, + speedup_factors, + ) + + # Case 9: Decimal values, negative dividend, positive divisor + run_benchmark_floor_divide( + "Decimal, negative dividend, positive divisor", + "-10.5", + "3.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Division by 1 + run_benchmark_floor_divide( + "Division by 1", + "10", + "1", + iterations, + log_file, + speedup_factors, + ) + + # Case 11: Zero dividend + run_benchmark_floor_divide( + "Zero dividend", + "0", + "5", + iterations, + log_file, + speedup_factors, + ) + + # Case 12: Division by a decimal < 1 + run_benchmark_floor_divide( + "Division by decimal < 1", + "10", + "0.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Division resulting in a negative zero + run_benchmark_floor_divide( + "Division resulting in zero", + "0", + "-5", + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Large number division + run_benchmark_floor_divide( + "Large number division", + "1000000000", + "7", + iterations, + log_file, + speedup_factors, + ) + + # Case 15: Small number division + run_benchmark_floor_divide( + "Small number division", + "0.0000001", + "0.0000002", + iterations, + log_file, + speedup_factors, + ) + + # Case 16: Very large dividend and divisor + run_benchmark_floor_divide( + "Large dividend and divisor", + "123456789012345", + "987654321", + iterations, + log_file, + speedup_factors, + ) + + # Case 17: High precision dividend + run_benchmark_floor_divide( + "High precision dividend", + "3.14159265358979323846", + "1.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 18: Very close values + run_benchmark_floor_divide( + "Very close values", + "1.0000001", + "1", + iterations, + log_file, + speedup_factors, + ) + + # Case 19: Power of 10 values + run_benchmark_floor_divide( + "Power of 10 values", + "10000", + "100", + iterations, + log_file, + speedup_factors, + ) + + # Case 20: Edge values not quite reaching the next integer + run_benchmark_floor_divide( + "Edge values", + "9.9999999", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor (ignoring any cases that might have failed) + if len(speedup_factors) > 0: + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== Floor Division (//) Benchmark Summary ===", log_file) + log_print( + "Benchmarked: " + + String(len(speedup_factors)) + + " different floor division cases", + log_file, + ) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print( + "Average speedup: " + String(average_speedup) + "×", log_file + ) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format( + i + 1, round(speedup_factors[i], 2) + ), + log_file, + ) + else: + log_print("\nNo valid benchmark cases were completed", log_file) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/benches/bench_modulo.mojo b/benches/bench_modulo.mojo new file mode 100644 index 0000000..74e085c --- /dev/null +++ b/benches/bench_modulo.mojo @@ -0,0 +1,406 @@ +""" +Comprehensive benchmarks for Decimal modulo (%) operation. +Compares performance against Python's decimal module with diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_modulo_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark_modulo( + name: String, + dividend: String, + divisor: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal modulo with Python Decimal modulo. + + Args: + name: Name of the benchmark case. + dividend: String representation of the dividend. + divisor: String representation of the divisor. + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Dividend: " + dividend, log_file) + log_print("Divisor: " + divisor, log_file) + + # Set up Mojo and Python values + var mojo_dividend = Decimal(dividend) + var mojo_divisor = Decimal(divisor) + var pydecimal = Python.import_module("decimal") + var py_dividend = pydecimal.Decimal(dividend) + var py_divisor = pydecimal.Decimal(divisor) + + # Execute the operations once to verify correctness + try: + var mojo_result = mojo_dividend % mojo_divisor + var py_result = py_dividend % py_divisor + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = mojo_dividend % mojo_divisor + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_dividend % py_divisor + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo %: " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python %: " + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + except e: + log_print("Error occurred during benchmark: " + String(e), log_file) + log_print("Skipping this benchmark case", log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Modulo (%) Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 10000 # Higher iterations as this operation should be fast + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning modulo (%) benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Simple modulo with remainder + run_benchmark_modulo( + "Simple modulo with remainder", + "10", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Modulo with no remainder + run_benchmark_modulo( + "Modulo with no remainder", + "10", + "5", + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Modulo with decimal values + run_benchmark_modulo( + "Decimal values (even division)", + "10.5", + "3.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Modulo with different decimal places + run_benchmark_modulo( + "Different decimal places", + "10.75", + "2.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Modulo with modulus > dividend + run_benchmark_modulo( + "Modulus > dividend", + "3", + "10", + iterations, + log_file, + speedup_factors, + ) + + # Case 6: Negative dividend, positive divisor + run_benchmark_modulo( + "Negative dividend, positive divisor", + "-10", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 7: Positive dividend, negative divisor + run_benchmark_modulo( + "Positive dividend, negative divisor", + "10", + "-3", + iterations, + log_file, + speedup_factors, + ) + + # Case 8: Negative dividend, negative divisor + run_benchmark_modulo( + "Negative dividend, negative divisor", + "-10", + "-3", + iterations, + log_file, + speedup_factors, + ) + + # Case 9: Decimal values, negative dividend + run_benchmark_modulo( + "Decimal values, negative dividend", + "-10.5", + "3.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Decimal values with remainder, negative dividend + run_benchmark_modulo( + "Decimal with remainder, negative dividend", + "-10.5", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Case 11: Modulo by 1 + run_benchmark_modulo( + "Modulo by 1", + "10", + "1", + iterations, + log_file, + speedup_factors, + ) + + # Case 12: Zero dividend + run_benchmark_modulo( + "Zero dividend", + "0", + "5", + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Modulo with a decimal < 1 + run_benchmark_modulo( + "Divisor < 1", + "10", + "0.3", + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Large number modulo + run_benchmark_modulo( + "Large number modulo", + "1000000007", + "13", + iterations, + log_file, + speedup_factors, + ) + + # Case 15: Small number modulo + run_benchmark_modulo( + "Small number modulo", + "0.0000023", + "0.0000007", + iterations, + log_file, + speedup_factors, + ) + + # Case 16: Equal values + run_benchmark_modulo( + "Equal values", + "7.5", + "7.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 17: High precision values + run_benchmark_modulo( + "High precision values", + "3.14159265358979323846", + "1.5", + iterations, + log_file, + speedup_factors, + ) + + # Case 18: Values close to exact multiple + run_benchmark_modulo( + "Values close to exact multiple", + "9.999999", + "2", + iterations, + log_file, + speedup_factors, + ) + + # Case 19: Large integers modulo + run_benchmark_modulo( + "Large integers modulo", + "12345678901234567890", + "9876543210", + iterations, + log_file, + speedup_factors, + ) + + # Case 20: Values that generate cyclic patterns + run_benchmark_modulo( + "Values generating repeating patterns", + "1", + "3", + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor (ignoring any cases that might have failed) + if len(speedup_factors) > 0: + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== Modulo (%) Benchmark Summary ===", log_file) + log_print( + "Benchmarked: " + + String(len(speedup_factors)) + + " different modulo cases", + log_file, + ) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print( + "Average speedup: " + String(average_speedup) + "×", log_file + ) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format( + i + 1, round(speedup_factors[i], 2) + ), + log_file, + ) + else: + log_print("\nNo valid benchmark cases were completed", log_file) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/mojoproject.toml b/mojoproject.toml index 76cab69..d622a45 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -34,6 +34,8 @@ t = "clear && magic run test" test_arith = "magic run package && magic run mojo test tests/test_arithmetics.mojo && magic run delete_package" test_multiply = "magic run package && magic run mojo test tests/test_multiply.mojo && magic run delete_package" test_divide = "magic run package && magic run mojo test tests/test_divide.mojo && magic run delete_package" +test_floor_divide = "magic run package && magic run mojo test tests/test_floor_divide.mojo && magic run delete_package" +test_modulo = "magic run package && magic run mojo test tests/test_modulo.mojo && magic run delete_package" test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && magic run delete_package" test_root = "magic run package && magic run mojo test tests/test_root.mojo && magic run delete_package" test_round = "magic run package && magic run mojo test tests/test_round.mojo && magic run delete_package" @@ -56,6 +58,8 @@ bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. & b = "clear && magic run bench" bench_multiply = "magic run package && cd benches && magic run mojo bench_multiply.mojo && cd .. && magic run delete_package" bench_divide = "magic run package && cd benches && magic run mojo bench_divide.mojo && cd .. && magic run delete_package" +bench_floor_divide = "magic run package && cd benches && magic run mojo bench_floor_divide.mojo && cd .. && magic run delete_package" +bench_modulo = "magic run package && cd benches && magic run mojo bench_modulo.mojo && cd .. && magic run delete_package" bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo && cd .. && magic run delete_package" bench_root = "magic run package && cd benches && magic run mojo bench_root.mojo && cd .. && magic run delete_package" bench_round = "magic run package && cd benches && magic run mojo bench_round.mojo && cd .. && magic run delete_package" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 3ed4b28..3db9daa 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -37,7 +37,9 @@ from .arithmetics import ( absolute, negative, multiply, - true_divide, + divide, + floor_divide, + modulo, ) from .comparison import ( diff --git a/src/decimojo/arithmetics.mojo b/src/decimojo/arithmetics.mojo index 5bbf0f2..4a2abf4 100644 --- a/src/decimojo/arithmetics.mojo +++ b/src/decimojo/arithmetics.mojo @@ -134,9 +134,12 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: if x1_coef > x2_coef: diff = x1_coef - x2_coef is_negative = x1.is_negative() - else: + elif x1_coef < x2_coef: diff = x2_coef - x1_coef is_negative = x2.is_negative() + else: # x1_coef == x2_coef + diff = UInt128(0) + is_negative = False return Decimal.from_uint128(diff, 0, is_negative) @@ -172,9 +175,12 @@ fn add(x1: Decimal, x2: Decimal) raises -> Decimal: if x1_coef > x2_coef: diff = x1.to_uint128() - x2.to_uint128() is_negative = x1.is_negative() - else: + elif x1_coef < x2_coef: diff = x2.to_uint128() - x1.to_uint128() is_negative = x2.is_negative() + else: # x1_coef == x2_coef + diff = UInt128(0) + is_negative = False # Determine the scale for the result var scale = min( @@ -708,7 +714,7 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: return Decimal(low, mid, high, final_scale, is_negative) -fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: +fn divide(x1: Decimal, x2: Decimal) raises -> Decimal: """ Divides x1 by x2 and returns a new Decimal containing the quotient. Uses a simpler string-based long division approach as fallback. @@ -1130,3 +1136,37 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal: var high = UInt32((truncated_quot >> 64) & 0xFFFFFFFF) return Decimal(low, mid, high, scale_of_truncated_quot, is_negative) + + +fn floor_divide(x1: Decimal, x2: Decimal) raises -> Decimal: + """Returns the integral part of the true quotient (truncating towards zero). + The following identity always holds: x_1 == (x_1 // x_2) * x_2 + x_1 % x_2. + + Args: + x1: The dividend. + x2: The divisor. + + Returns: + A new Decimal containing the integral part of x1 / x2. + """ + try: + return divide(x1, x2).round(0, RoundingMode.ROUND_DOWN) + except e: + raise Error("Error in `divide()`: ", e) + + +fn modulo(x1: Decimal, x2: Decimal) raises -> Decimal: + """Returns the remainder of the division of x1 by x2. + The following identity always holds: x_1 == (x_1 // x_2) * x_2 + x_1 % x_2. + + Args: + x1: The dividend. + x2: The divisor. + + Returns: + A new Decimal containing the remainder of x1 / x2. + """ + try: + return x1 - (floor_divide(x1, x2) * x2) + except e: + raise Error("Error in `modulo()`: ", e) diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index fcd32c7..74eed37 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -69,11 +69,6 @@ struct Decimal( The value of the coefficient is: `high * 2**64 + mid * 2**32 + low` The final value is: `(-1)**sign * coefficient * 10**(-scale)` - - Reference: - - - General Decimal Arithmetic Specification Version 1.70 – 7 Apr 2009 (https://speleotrove.com/decimal/decarith.html) - - https://learn.microsoft.com/en-us/dotnet/api/system.decimal.getbits?view=net-9.0&redirectedfrom=MSDN#System_Decimal_GetBits_System_Decimal_ """ # ===------------------------------------------------------------------=== # @@ -1187,11 +1182,27 @@ struct Decimal( @always_inline fn __truediv__(self, other: Self) raises -> Self: - return decimojo.arithmetics.true_divide(self, other) + return decimojo.arithmetics.divide(self, other) @always_inline fn __truediv__(self, other: Int) raises -> Self: - return decimojo.arithmetics.true_divide(self, Decimal(other)) + return decimojo.arithmetics.divide(self, Decimal(other)) + + @always_inline + fn __floordiv__(self, other: Self) raises -> Self: + return decimojo.arithmetics.floor_divide(self, other) + + @always_inline + fn __floordiv__(self, other: Int) raises -> Self: + return decimojo.arithmetics.floor_divide(self, Decimal(other)) + + @always_inline + fn __mod__(self, other: Self) raises -> Self: + return decimojo.arithmetics.modulo(self, other) + + @always_inline + fn __mod__(self, other: Int) raises -> Self: + return decimojo.arithmetics.modulo(self, Decimal(other)) @always_inline fn __pow__(self, exponent: Self) raises -> Self: @@ -1210,31 +1221,27 @@ struct Decimal( @always_inline fn __radd__(self, other: Int) raises -> Self: - try: - return decimojo.arithmetics.add(Decimal(other), self) - except e: - raise Error("Error in `__radd__()`: ", e) + return decimojo.arithmetics.add(Decimal(other), self) @always_inline fn __rsub__(self, other: Int) raises -> Self: - try: - return decimojo.arithmetics.subtract(Decimal(other), self) - except e: - raise Error("Error in `__rsub__()`: ", e) + return decimojo.arithmetics.subtract(Decimal(other), self) @always_inline fn __rmul__(self, other: Int) raises -> Self: - try: - return decimojo.arithmetics.multiply(Decimal(other), self) - except e: - raise Error("Error in `__rmul__()`: ", e) + return decimojo.arithmetics.multiply(Decimal(other), self) @always_inline fn __rtruediv__(self, other: Int) raises -> Self: - try: - return decimojo.arithmetics.true_divide(Decimal(other), self) - except e: - raise Error("Error in `__rtruediv__()`: ", e) + return decimojo.arithmetics.divide(Decimal(other), self) + + @always_inline + fn __rfloordiv__(self, other: Int) raises -> Self: + return decimojo.arithmetics.floor_divide(Decimal(other), self) + + @always_inline + fn __rmod__(self, other: Int) raises -> Self: + return decimojo.arithmetics.modulo(Decimal(other), self) # ===------------------------------------------------------------------=== # # Basic binary augmented arithmetic assignments dunders @@ -1269,11 +1276,19 @@ struct Decimal( @always_inline fn __itruediv__(mut self, other: Self) raises: - self = decimojo.arithmetics.true_divide(self, other) + self = decimojo.arithmetics.divide(self, other) @always_inline fn __itruediv__(mut self, other: Int) raises: - self = decimojo.arithmetics.true_divide(self, Decimal(other)) + self = decimojo.arithmetics.divide(self, Decimal(other)) + + @always_inline + fn __ifloordiv__(mut self, other: Self) raises: + self = decimojo.arithmetics.floor_divide(self, other) + + @always_inline + fn __ifloordiv__(mut self, other: Int) raises: + self = decimojo.arithmetics.floor_divide(self, Decimal(other)) # ===------------------------------------------------------------------=== # # Basic binary comparison operation dunders @@ -1338,7 +1353,9 @@ struct Decimal( """ try: return decimojo.rounding.round( - self, ndigits=ndigits, rounding_mode=RoundingMode.HALF_EVEN() + self, + ndigits=ndigits, + rounding_mode=RoundingMode.ROUND_HALF_EVEN, ) except e: return self @@ -1348,7 +1365,7 @@ struct Decimal( """**OVERLOAD**.""" try: return decimojo.rounding.round( - self, ndigits=0, rounding_mode=RoundingMode.HALF_EVEN() + self, ndigits=0, rounding_mode=RoundingMode.ROUND_HALF_EVEN ) except e: return self @@ -1537,18 +1554,6 @@ struct Decimal( """Returns the scale (number of decimal places) of this Decimal.""" return Int((self.flags & Self.SCALE_MASK) >> Self.SCALE_SHIFT) - @always_inline - fn scientific_exponent(self) -> Int: - """ - Calculates the exponent for scientific notation representation of a Decimal. - The exponent is the power of 10 needed to represent the value in scientific notation. - """ - # Special case for zero - if self.is_zero(): - return self.scale() - - return self.number_of_significant_digits() - 1 - self.scale() - @always_inline fn number_of_significant_digits(self) -> Int: """Returns the number of significant digits in the Decimal. diff --git a/src/decimojo/rounding_mode.mojo b/src/decimojo/rounding_mode.mojo index e424bcc..f4c1333 100644 --- a/src/decimojo/rounding_mode.mojo +++ b/src/decimojo/rounding_mode.mojo @@ -38,10 +38,10 @@ struct RoundingMode: """ # alias - alias ROUND_DOWN = Self.DOWN() - alias ROUND_HALF_UP = Self.HALF_UP() - alias ROUND_HALF_EVEN = Self.HALF_EVEN() - alias ROUND_UP = Self.UP() + alias ROUND_DOWN = Self.down() + alias ROUND_HALF_UP = Self.half_up() + alias ROUND_HALF_EVEN = Self.half_even() + alias ROUND_UP = Self.up() # Internal value var value: Int @@ -49,22 +49,22 @@ struct RoundingMode: # Static constants for each rounding mode @staticmethod - fn DOWN() -> Self: + fn down() -> Self: """Truncate (toward zero).""" return Self(0) @staticmethod - fn HALF_UP() -> Self: + fn half_up() -> Self: """Round away from zero if >= 0.5.""" return Self(1) @staticmethod - fn HALF_EVEN() -> Self: + fn half_even() -> Self: """Round to nearest even digit if equidistant (banker's rounding).""" return Self(2) @staticmethod - fn UP() -> Self: + fn up() -> Self: """Round away from zero.""" return Self(3) @@ -78,13 +78,13 @@ struct RoundingMode: return String(self) == other fn __str__(self) -> String: - if self == Self.DOWN(): + if self == Self.ROUND_DOWN: return "ROUND_DOWN" - elif self == Self.HALF_UP(): + elif self == Self.ROUND_UP: return "ROUND_HALF_UP" - elif self == Self.HALF_EVEN(): + elif self == Self.ROUND_HALF_EVEN: return "ROUND_HALF_EVEN" - elif self == Self.UP(): + elif self == Self.ROUND_UP: return "ROUND_UP" else: return "UNKNOWN_ROUNDING_MODE" diff --git a/tests/test_floor_divide.mojo b/tests/test_floor_divide.mojo new file mode 100644 index 0000000..4c8195e --- /dev/null +++ b/tests/test_floor_divide.mojo @@ -0,0 +1,282 @@ +""" +Comprehensive tests for the floor division (//) operation of the Decimal type. +Tests various scenarios to ensure proper integer division behavior. +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode + + +fn test_basic_floor_division() raises: + """Test basic integer floor division.""" + print("Testing basic floor division...") + + # Test case 1: Simple integer division with no remainder + var a1 = Decimal(10) + var b1 = Decimal(2) + var result1 = a1 // b1 + testing.assert_equal( + String(result1), "5", "10 // 2 should equal 5, got " + String(result1) + ) + + # Test case 2: Simple integer division with remainder + var a2 = Decimal(10) + var b2 = Decimal(3) + var result2 = a2 // b2 + testing.assert_equal( + String(result2), "3", "10 // 3 should equal 3, got " + String(result2) + ) + + # Test case 3: Division with decimal values + var a3 = Decimal("10.5") + var b3 = Decimal("2.5") + var result3 = a3 // b3 + testing.assert_equal( + String(result3), + "4", + "10.5 // 2.5 should equal 4, got " + String(result3), + ) + + # Test case 4: Division resulting in a decimal value + var a4 = Decimal(5) + var b4 = Decimal(2) + var result4 = a4 // b4 + testing.assert_equal( + String(result4), "2", "5 // 2 should equal 2, got " + String(result4) + ) + + # Test case 5: Division with different decimal places + var a5 = Decimal("10.75") + var b5 = Decimal("1.5") + var result5 = a5 // b5 + testing.assert_equal( + String(result5), + "7", + "10.75 // 1.5 should equal 7, got " + String(result5), + ) + + print("✓ Basic floor division tests passed!") + + +fn test_negative_floor_division() raises: + """Test floor division involving negative numbers.""" + print("Testing floor division with negative numbers...") + + # Test case 1: Negative // Positive + var a1 = Decimal(-10) + var b1 = Decimal(3) + var result1 = a1 // b1 + testing.assert_equal( + String(result1), + "-3", + "-10 // 3 should equal -3, got " + String(result1), + ) + + # Test case 2: Positive // Negative + var a2 = Decimal(10) + var b2 = Decimal(-3) + var result2 = a2 // b2 + testing.assert_equal( + String(result2), + "-3", + "10 // -3 should equal -3, got " + String(result2), + ) + + # Test case 3: Negative // Negative + var a3 = Decimal(-10) + var b3 = Decimal(-3) + var result3 = a3 // b3 + testing.assert_equal( + String(result3), "3", "-10 // -3 should equal 3, got " + String(result3) + ) + + # Test case 4: Decimal values, Negative // Positive + var a4 = Decimal("-10.5") + var b4 = Decimal("3.5") + var result4 = a4 // b4 + testing.assert_equal( + String(result4), + "-3", + "-10.5 // 3.5 should equal -3, got " + String(result4), + ) + + # Test case 5: Decimal values, Positive // Negative + var a5 = Decimal("10.5") + var b5 = Decimal("-3.5") + var result5 = a5 // b5 + testing.assert_equal( + String(result5), + "-3", + "10.5 // -3.5 should equal -3, got " + String(result5), + ) + + # Test case 6: Decimal values, Negative // Negative + var a6 = Decimal("-10.5") + var b6 = Decimal("-3.5") + var result6 = a6 // b6 + testing.assert_equal( + String(result6), + "3", + "-10.5 // -3.5 should equal 3, got " + String(result6), + ) + + print("✓ Negative number floor division tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for floor division.""" + print("Testing floor division edge cases...") + + # Test case 1: Division by 1 + var a1 = Decimal(10) + var b1 = Decimal(1) + var result1 = a1 // b1 + testing.assert_equal( + String(result1), "10", "10 // 1 should equal 10, got " + String(result1) + ) + + # Test case 2: Zero dividend + var a2 = Decimal(0) + var b2 = Decimal(5) + var result2 = a2 // b2 + testing.assert_equal( + String(result2), "0", "0 // 5 should equal 0, got " + String(result2) + ) + + # Test case 3: Division by a decimal < 1 + var a3 = Decimal(10) + var b3 = Decimal("0.5") + var result3 = a3 // b3 + testing.assert_equal( + String(result3), + "20", + "10 // 0.5 should equal 20, got " + String(result3), + ) + + # Test case 4: Division resulting in a negative zero (should be 0) + var a4 = Decimal(0) + var b4 = Decimal(-5) + var result4 = a4 // b4 + testing.assert_equal( + String(result4), "0", "0 // -5 should equal 0, got " + String(result4) + ) + + # Test case 5: Division by zero (should raise error) + var a5 = Decimal(10) + var b5 = Decimal(0) + var exception_caught = False + try: + var _result5 = a5 // b5 + testing.assert_equal(True, False, "Division by zero should raise error") + except: + exception_caught = True + testing.assert_equal( + exception_caught, True, "Division by zero should raise error" + ) + + # Test case 6: Large number division + var a6 = Decimal("1000000000") + var b6 = Decimal("7") + var result6 = a6 // b6 + testing.assert_equal( + String(result6), "142857142", "1000000000 // 7 calculated incorrectly" + ) + + # Test case 7: Small number division + var a7 = Decimal("0.0000001") + var b7 = Decimal("0.0000002") + var result7 = a7 // b7 + testing.assert_equal( + String(result7), "0", "0.0000001 // 0.0000002 should equal 0" + ) + + print("✓ Edge cases tests passed!") + + +fn test_mathematical_relationships() raises: + """Test mathematical relationships involving floor division.""" + print("Testing mathematical relationships...") + + # Test case 1: a = (a // b) * b + (a % b) + var a1 = Decimal(10) + var b1 = Decimal(3) + var floor_div = a1 // b1 + var mod_result = a1 % b1 + var reconstructed = floor_div * b1 + mod_result + testing.assert_equal( + String(reconstructed), + String(a1), + "a should equal (a // b) * b + (a % b)", + ) + + # Test case 2: a // b = floor(a / b) + var a2 = Decimal("10.5") + var b2 = Decimal("2.5") + var floor_div2 = a2 // b2 + var div_floored = (a2 / b2).round(0, RoundingMode.ROUND_DOWN) + testing.assert_equal( + String(floor_div2), + String(div_floored), + "a // b should equal floor(a / b)", + ) + + # Test case 3: Relationship with negative values + var a3 = Decimal(-10) + var b3 = Decimal(3) + var floor_div3 = a3 // b3 + var mod_result3 = a3 % b3 + var reconstructed3 = floor_div3 * b3 + mod_result3 + testing.assert_equal( + String(reconstructed3), + String(a3), + "a should equal (a // b) * b + (a % b) with negative values", + ) + + # Test case 4: (a // b) * b ≤ a < (a // b + 1) * b + var a4 = Decimal("10.5") + var b4 = Decimal("3.2") + var floor_div4 = a4 // b4 + var lower_bound = floor_div4 * b4 + var upper_bound = (floor_div4 + Decimal(1)) * b4 + testing.assert_true( + (lower_bound <= a4) and (a4 < upper_bound), + "Relationship (a // b) * b ≤ a < (a // b + 1) * b should hold", + ) + + print("✓ Mathematical relationships tests passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Decimal Floor Division Tests") + print("=========================================") + + run_test_with_error_handling( + test_basic_floor_division, "Basic floor division test" + ) + run_test_with_error_handling( + test_negative_floor_division, "Negative number floor division test" + ) + run_test_with_error_handling(test_edge_cases, "Edge cases test") + run_test_with_error_handling( + test_mathematical_relationships, "Mathematical relationships test" + ) + + print("All floor division tests passed!") diff --git a/tests/test_modulo.mojo b/tests/test_modulo.mojo new file mode 100644 index 0000000..fbf0d15 --- /dev/null +++ b/tests/test_modulo.mojo @@ -0,0 +1,348 @@ +""" +Comprehensive tests for the modulo (%) operation of the Decimal type. +Tests various scenarios to ensure proper remainder calculation behavior. +""" + +import testing +from decimojo.prelude import dm, Decimal, RoundingMode + + +fn test_basic_modulo() raises: + """Test basic modulo operations with positive integers.""" + print("Testing basic modulo operations...") + + # Test case 1: Simple modulo with remainder + var a1 = Decimal(10) + var b1 = Decimal(3) + var result1 = a1 % b1 + testing.assert_equal( + String(result1), "1", "10 % 3 should equal 1, got " + String(result1) + ) + + # Test case 2: Modulo with no remainder + var a2 = Decimal(10) + var b2 = Decimal(5) + var result2 = a2 % b2 + testing.assert_equal( + String(result2), "0", "10 % 5 should equal 0, got " + String(result2) + ) + + # Test case 3: Modulo with decimal values + var a3 = Decimal("10.5") + var b3 = Decimal("3.5") + var result3 = a3 % b3 + testing.assert_equal( + String(result3), + "0.0", + "10.5 % 3.5 should equal 0.0, got " + String(result3), + ) + + # Test case 4: Modulo with different decimal places + var a4 = Decimal("10.75") + var b4 = Decimal("2.5") + var result4 = a4 % b4 + testing.assert_equal( + String(result4), + "0.75", + "10.75 % 2.5 should equal 0.75, got " + String(result4), + ) + + # Test case 5: Modulo with modulus > dividend + var a5 = Decimal(3) + var b5 = Decimal(10) + var result5 = a5 % b5 + testing.assert_equal( + String(result5), "3", "3 % 10 should equal 3, got " + String(result5) + ) + + print("✓ Basic modulo operations tests passed!") + + +fn test_negative_modulo() raises: + """Test modulo operations involving negative numbers.""" + print("Testing modulo with negative numbers...") + + # Test case 1: Negative dividend, positive divisor + var a1 = Decimal(-10) + var b1 = Decimal(3) + var result1 = a1 % b1 + testing.assert_equal( + String(result1), "-1", "-10 % 3 should equal -1, got " + String(result1) + ) + + # Test case 2: Positive dividend, negative divisor + var a2 = Decimal(10) + var b2 = Decimal(-3) + var result2 = a2 % b2 + testing.assert_equal( + String(result2), "1", "10 % -3 should equal 1, got " + String(result2) + ) + + # Test case 3: Negative dividend, negative divisor + var a3 = Decimal(-10) + var b3 = Decimal(-3) + var result3 = a3 % b3 + testing.assert_equal( + String(result3), + "-1", + "-10 % -3 should equal -1, got " + String(result3), + ) + + # Test case 4: Decimal values, Negative dividend, positive divisor + var a4 = Decimal("-10.5") + var b4 = Decimal("3.5") + var result4 = a4 % b4 + testing.assert_equal( + String(result4), + "0.0", + "-10.5 % 3.5 should equal 0.0, got " + String(result4), + ) + + # Test case 5: Decimal values with remainder, Negative dividend, positive divisor + var a5 = Decimal("-10.5") + var b5 = Decimal("3") + var result5 = a5 % b5 + testing.assert_equal( + String(result5), + "-1.5", + "-10.5 % 3 should equal -1.5, got " + String(result5), + ) + + # Test case 6: Decimal values with remainder, Positive dividend, negative divisor + var a6 = Decimal("10.5") + var b6 = Decimal("-3") + var result6 = a6 % b6 + testing.assert_equal( + String(result6), + "1.5", + "10.5 % -3 should equal 1.5, got " + String(result6), + ) + + print("✓ Negative number modulo tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for modulo operation.""" + print("Testing modulo edge cases...") + + # Test case 1: Modulo by 1 + var a1 = Decimal(10) + var b1 = Decimal(1) + var result1 = a1 % b1 + testing.assert_equal( + String(result1), "0", "10 % 1 should equal 0, got " + String(result1) + ) + + # Test case 2: Zero dividend + var a2 = Decimal(0) + var b2 = Decimal(5) + var result2 = a2 % b2 + testing.assert_equal( + String(result2), "0", "0 % 5 should equal 0, got " + String(result2) + ) + + # Test case 3: Modulo with a decimal < 1 + var a3 = Decimal(10) + var b3 = Decimal("0.3") + var result3 = a3 % b3 + testing.assert_equal( + String(result3), + "0.1", + "10 % 0.3 should equal 0.1, got " + String(result3), + ) + + # Test case 4: Modulo by zero (should raise error) + var a4 = Decimal(10) + var b4 = Decimal(0) + var exception_caught = False + try: + var _result4 = a4 % b4 + testing.assert_equal(True, False, "Modulo by zero should raise error") + except: + exception_caught = True + testing.assert_equal( + exception_caught, True, "Modulo by zero should raise error" + ) + + # Test case 5: Large number modulo + var a5 = Decimal("1000000007") + var b5 = Decimal("13") + var result5 = a5 % b5 + testing.assert_equal( + String(result5), "6", "1000000007 % 13 calculated incorrectly" + ) + + # Test case 6: Small number modulo + var a6 = Decimal("0.0000023") + var b6 = Decimal("0.0000007") + var result6 = a6 % b6 + testing.assert_equal( + String(result6), + "0.0000002", + "0.0000023 % 0.0000007 calculated incorrectly", + ) + + # Test case 7: Equal values + var a7 = Decimal("7.5") + var b7 = Decimal("7.5") + var result7 = a7 % b7 + testing.assert_equal( + String(result7), + "0.0", + "7.5 % 7.5 should equal 0.0, got " + String(result7), + ) + + print("✓ Edge cases tests passed!") + + +fn test_mathematical_relationships() raises: + """Test mathematical relationships involving modulo.""" + print("Testing mathematical relationships...") + + # Test case 1: a = (a // b) * b + (a % b) + var a1 = Decimal(10) + var b1 = Decimal(3) + var floor_div = a1 // b1 + var mod_result = a1 % b1 + var reconstructed = floor_div * b1 + mod_result + testing.assert_equal( + String(reconstructed), + String(a1), + "a should equal (a // b) * b + (a % b)", + ) + + # Test case 2: 0 <= (a % b) < b for positive b + var a2 = Decimal("10.5") + var b2 = Decimal("3.2") + var mod_result2 = a2 % b2 + testing.assert_true( + (mod_result2 >= Decimal(0)) and (mod_result2 < b2), + "For positive b, 0 <= (a % b) < b should hold", + ) + + # Test case 3: Relationship with negative values + var a3 = Decimal(-10) + var b3 = Decimal(3) + var floor_div3 = a3 // b3 + var mod_result3 = a3 % b3 + var reconstructed3 = floor_div3 * b3 + mod_result3 + testing.assert_equal( + String(reconstructed3), + String(a3), + "a should equal (a // b) * b + (a % b) with negative values", + ) + + # Test case 4: a % b for negative b + var a4 = Decimal("10.5") + var b4 = Decimal("-3.2") + var mod_result4 = a4 % b4 + testing.assert_true( + mod_result4 == Decimal("0.9"), + "10.5 % -3.2 should equal 0.9, got " + String(mod_result4), + ) + + # Test case 5: (a % b) % b = a % b + var a5 = Decimal(17) + var b5 = Decimal(5) + var mod_once = a5 % b5 + var mod_twice = mod_once % b5 + testing.assert_equal( + String(mod_once), String(mod_twice), "(a % b) % b should equal a % b" + ) + + print("✓ Mathematical relationships tests passed!") + + +fn test_consistency_with_floor_division() raises: + """Test consistency between modulo and floor division operations.""" + print("Testing consistency with floor division...") + + # Test case 1: a % b and a - (a // b) * b + var a1 = Decimal(10) + var b1 = Decimal(3) + var mod_result = a1 % b1 + var floor_div = a1 // b1 + var calc_mod = a1 - floor_div * b1 + testing.assert_equal( + String(mod_result), + String(calc_mod), + "a % b should equal a - (a // b) * b", + ) + + # Test case 2: Consistency with negative values + var a2 = Decimal(-10) + var b2 = Decimal(3) + var mod_result2 = a2 % b2 + var floor_div2 = a2 // b2 + var calc_mod2 = a2 - floor_div2 * b2 + testing.assert_equal( + String(mod_result2), + String(calc_mod2), + "a % b should equal a - (a // b) * b with negative values", + ) + + # Test case 3: Consistency with decimal values + var a3 = Decimal("10.5") + var b3 = Decimal("2.5") + var mod_result3 = a3 % b3 + var floor_div3 = a3 // b3 + var calc_mod3 = a3 - floor_div3 * b3 + testing.assert_equal( + String(mod_result3), + String(calc_mod3), + "a % b should equal a - (a // b) * b with decimal values", + ) + + # Test case 4: Consistency with mixed positive and negative values + var a4 = Decimal(10) + var b4 = Decimal(-3) + var mod_result4 = a4 % b4 + var floor_div4 = a4 // b4 + var calc_mod4 = a4 - floor_div4 * b4 + testing.assert_equal( + String(mod_result4), + String(calc_mod4), + "a % b should equal a - (a // b) * b with mixed signs", + ) + + print("✓ Consistency with floor division tests passed!") + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Decimal Modulo Tests") + print("=========================================") + + run_test_with_error_handling( + test_basic_modulo, "Basic modulo operations test" + ) + run_test_with_error_handling( + test_negative_modulo, "Negative number modulo test" + ) + run_test_with_error_handling(test_edge_cases, "Edge cases test") + run_test_with_error_handling( + test_mathematical_relationships, "Mathematical relationships test" + ) + run_test_with_error_handling( + test_consistency_with_floor_division, + "Consistency with floor division test", + ) + + print("All modulo tests passed!") diff --git a/tests/test_round.mojo b/tests/test_round.mojo index 90a80a3..2208acc 100644 --- a/tests/test_round.mojo +++ b/tests/test_round.mojo @@ -43,19 +43,19 @@ fn test_different_rounding_modes() raises: var test_value = Decimal("123.456") # Test case 1: Round down (truncate) - var result1 = dm.round(test_value, 2, RoundingMode.DOWN()) + var result1 = dm.round(test_value, 2, RoundingMode.ROUND_DOWN) testing.assert_equal(String(result1), "123.45", "Rounding down") # Test case 2: Round up (away from zero) - var result2 = dm.round(test_value, 2, RoundingMode.UP()) + var result2 = dm.round(test_value, 2, RoundingMode.ROUND_UP) testing.assert_equal(String(result2), "123.46", "Rounding up") # Test case 3: Round half up - var result3 = dm.round(test_value, 2, RoundingMode.HALF_UP()) + var result3 = dm.round(test_value, 2, RoundingMode.ROUND_HALF_UP) testing.assert_equal(String(result3), "123.46", "Rounding half up") # Test case 4: Round half even (banker's rounding) - var result4 = dm.round(test_value, 2, RoundingMode.HALF_EVEN()) + var result4 = dm.round(test_value, 2, RoundingMode.ROUND_HALF_EVEN) testing.assert_equal(String(result4), "123.46", "Rounding half even") print("Rounding mode tests passed!") @@ -68,25 +68,25 @@ fn test_edge_cases() raises: var half_value = Decimal("123.5") testing.assert_equal( - String(dm.round(half_value, 0, RoundingMode.DOWN())), + String(dm.round(half_value, 0, RoundingMode.ROUND_DOWN)), "123", "Rounding 0.5 down", ) testing.assert_equal( - String(dm.round(half_value, 0, RoundingMode.UP())), + String(dm.round(half_value, 0, RoundingMode.ROUND_UP)), "124", "Rounding 0.5 up", ) testing.assert_equal( - String(dm.round(half_value, 0, RoundingMode.HALF_UP())), + String(dm.round(half_value, 0, RoundingMode.ROUND_HALF_UP)), "124", "Rounding 0.5 half up", ) testing.assert_equal( - String(dm.round(half_value, 0, RoundingMode.HALF_EVEN())), + String(dm.round(half_value, 0, RoundingMode.ROUND_HALF_EVEN)), "124", "Rounding 0.5 half even (even is 124)", ) @@ -94,7 +94,7 @@ fn test_edge_cases() raises: # Another test with half to even value var half_even_value = Decimal("124.5") testing.assert_equal( - String(dm.round(half_even_value, 0, RoundingMode.HALF_EVEN())), + String(dm.round(half_even_value, 0, RoundingMode.ROUND_HALF_EVEN)), "124", "Rounding 124.5 half even (even is 124)", ) @@ -113,19 +113,19 @@ fn test_edge_cases() raises: var negative_value = Decimal("-123.456") testing.assert_equal( - String(dm.round(negative_value, 2, RoundingMode.DOWN())), + String(dm.round(negative_value, 2, RoundingMode.ROUND_DOWN)), "-123.45", "Rounding negative number down", ) testing.assert_equal( - String(dm.round(negative_value, 2, RoundingMode.UP())), + String(dm.round(negative_value, 2, RoundingMode.ROUND_UP)), "-123.46", "Rounding negative number up", ) testing.assert_equal( - String(dm.round(negative_value, 2, RoundingMode.HALF_EVEN())), + String(dm.round(negative_value, 2, RoundingMode.ROUND_HALF_EVEN)), "-123.46", "Rounding negative number half even", )