diff --git a/benches/bench.mojo b/benches/bench.mojo index fe2a245..623cb7e 100644 --- a/benches/bench.mojo +++ b/benches/bench.mojo @@ -7,6 +7,7 @@ from bench_from_float import main as bench_from_float from bench_from_string import main as bench_from_string from bench_from_int import main as bench_from_int from bench_round import main as bench_round +from bench_quantize import main as bench_quantize from bench_comparison import main as bench_comparison from bench_exp import main as bench_exp from bench_ln import main as bench_ln @@ -25,6 +26,7 @@ fn main() raises: bench_from_string() bench_from_int() bench_round() + bench_quantize() bench_comparison() bench_exp() bench_ln() diff --git a/benches/bench_quantize.mojo b/benches/bench_quantize.mojo new file mode 100644 index 0000000..ff68a6c --- /dev/null +++ b/benches/bench_quantize.mojo @@ -0,0 +1,442 @@ +""" +Comprehensive benchmarks for Decimal.quantize() method. +Compares performance against Python's decimal module with diverse test cases. +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_quantize_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_benchmark_quantize( + name: String, + value_str: String, + quant_str: String, + rounding_mode: RoundingMode, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal.quantize with Python Decimal.quantize. + + Args: + name: Name of the benchmark case. + value_str: String representation of the value to quantize. + quant_str: String representation of the quantizer. + rounding_mode: The rounding mode to use. + iterations: Number of iterations to run. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Value: " + value_str, log_file) + log_print("Quantizer: " + quant_str, log_file) + log_print("Rounding mode: " + String(rounding_mode), log_file) + + # Set up Mojo and Python values + var mojo_value = Decimal(value_str) + var mojo_quant = Decimal(quant_str) + var pydecimal = Python.import_module("decimal") + var py_value = pydecimal.Decimal(value_str) + var py_quant = pydecimal.Decimal(quant_str) + + # Map Mojo rounding mode to Python rounding mode + var py_rounding_mode: PythonObject + if rounding_mode == RoundingMode.ROUND_HALF_EVEN: + py_rounding_mode = pydecimal.ROUND_HALF_EVEN + elif rounding_mode == RoundingMode.ROUND_HALF_UP: + py_rounding_mode = pydecimal.ROUND_HALF_UP + elif rounding_mode == RoundingMode.ROUND_UP: + py_rounding_mode = pydecimal.ROUND_UP + elif rounding_mode == RoundingMode.ROUND_DOWN: + py_rounding_mode = pydecimal.ROUND_DOWN + else: + py_rounding_mode = pydecimal.ROUND_HALF_EVEN # Default + + # Execute the operations once to verify correctness + try: + var mojo_result = mojo_value.quantize(mojo_quant, rounding_mode) + var py_result = py_value.quantize(py_quant, rounding=py_rounding_mode) + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + _ = mojo_value.quantize(mojo_quant, rounding_mode) + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Prevent division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + _ = py_value.quantize(py_quant, rounding=py_rounding_mode) + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo quantize(): " + String(mojo_time) + " ns per iteration", + log_file, + ) + log_print( + "Python quantize():" + String(python_time) + " ns per iteration", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + except e: + log_print("Error occurred during benchmark: " + String(e), log_file) + log_print("Skipping this benchmark case", log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo quantize() Method Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 10000 # Higher iterations as this operation should be fast + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning quantize() method benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Case 1: Basic quantization - rounding to 2 decimal places + run_benchmark_quantize( + "Round to 2 decimal places", + "3.14159", + "0.01", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 2: Round to integer + run_benchmark_quantize( + "Round to integer", + "42.7", + "1", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 3: Increase precision (add trailing zeros) + run_benchmark_quantize( + "Increase precision", + "5.5", + "0.001", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 4: Decrease precision (round) + run_benchmark_quantize( + "Decrease precision", + "123.456789", + "0.01", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 5: Different exponent patterns + run_benchmark_quantize( + "Different exponent pattern", + "9.876", + "1.00", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 6: ROUND_HALF_UP rounding mode + run_benchmark_quantize( + "ROUND_HALF_UP mode", + "3.5", + "1", + RoundingMode.ROUND_HALF_UP, + iterations, + log_file, + speedup_factors, + ) + + # Case 7: ROUND_DOWN rounding mode + run_benchmark_quantize( + "ROUND_DOWN mode", + "3.9", + "1", + RoundingMode.ROUND_DOWN, + iterations, + log_file, + speedup_factors, + ) + + # Case 8: ROUND_UP rounding mode + run_benchmark_quantize( + "ROUND_UP mode", + "3.1", + "1", + RoundingMode.ROUND_UP, + iterations, + log_file, + speedup_factors, + ) + + # Case 9: Negative number - ROUND_HALF_EVEN + run_benchmark_quantize( + "Negative number - ROUND_HALF_EVEN", + "-1.5", + "1", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 10: Quantizing zero + run_benchmark_quantize( + "Quantizing zero", + "0", + "0.001", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 11: Quantizing to same exponent (no change) + run_benchmark_quantize( + "Quantizing to same exponent", + "123.45", + "0.01", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 12: Numbers that need significant rounding + run_benchmark_quantize( + "Significant rounding", + "9.9999", + "1", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 13: Very small number + run_benchmark_quantize( + "Very small number", + "0.0000001", + "0.001", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 14: Banker's rounding for 2.5 + run_benchmark_quantize( + "Banker's rounding (2.5)", + "2.5", + "1", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 15: Quantizing with negative exponent + run_benchmark_quantize( + "Quantizing to tens place", + "123.456", + "10", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 16: High precision + run_benchmark_quantize( + "High precision quantizing", + "3.1415926535", + "0.000001", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 17: Rounding to hundreds + run_benchmark_quantize( + "Rounding to hundreds", + "750", + "100", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 18: Value very close to rounding threshold + run_benchmark_quantize( + "Value close to threshold", + "0.9999999", + "1", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 19: Pi to 2 decimal places + run_benchmark_quantize( + "Pi to 2 decimal places", + "3.14159265358979323", + "0.01", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Case 20: Zero with trailing zeros + run_benchmark_quantize( + "Zero with trailing zeros", + "0.0", + "0.0000", + RoundingMode.ROUND_HALF_EVEN, + iterations, + log_file, + speedup_factors, + ) + + # Calculate average speedup factor (ignoring any cases that might have failed) + if len(speedup_factors) > 0: + var sum_speedup: Float64 = 0.0 + for i in range(len(speedup_factors)): + sum_speedup += speedup_factors[i] + var average_speedup = sum_speedup / Float64(len(speedup_factors)) + + # Display summary + log_print("\n=== quantize() Method Benchmark Summary ===", log_file) + log_print( + "Benchmarked: " + + String(len(speedup_factors)) + + " different quantize() cases", + log_file, + ) + log_print( + "Each case ran: " + String(iterations) + " iterations", log_file + ) + log_print( + "Average speedup: " + String(average_speedup) + "×", log_file + ) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format( + i + 1, round(speedup_factors[i], 2) + ), + log_file, + ) + else: + log_print("\nNo valid benchmark cases were completed", log_file) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/mojoproject.toml b/mojoproject.toml index 1c4c2be..76cab69 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -8,6 +8,9 @@ platforms = ["osx-arm64", "linux-64"] readme = "README.md" version = "0.1.0" +[dependencies] +max = ">=25.1,<25.3" + [tasks] # format the code format = "magic run mojo format ./" @@ -34,6 +37,7 @@ test_divide = "magic run package && magic run mojo test tests/test_divide.mojo & test_sqrt = "magic run package && magic run mojo test tests/test_sqrt.mojo && magic run delete_package" test_root = "magic run package && magic run mojo test tests/test_root.mojo && magic run delete_package" test_round = "magic run package && magic run mojo test tests/test_round.mojo && magic run delete_package" +test_quantize = "magic run package && magic run mojo test tests/test_quantize.mojo && magic run delete_package" test_from_components = "magic run package && magic run mojo test tests/test_from_components.mojo && magic run delete_package" test_from_float = "magic run package && magic run mojo test tests/test_from_float.mojo && magic run delete_package" test_from_string = "magic run package && magic run mojo test tests/test_from_string.mojo && magic run delete_package" @@ -55,6 +59,7 @@ bench_divide = "magic run package && cd benches && magic run mojo bench_divide.m bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo && cd .. && magic run delete_package" bench_root = "magic run package && cd benches && magic run mojo bench_root.mojo && cd .. && magic run delete_package" bench_round = "magic run package && cd benches && magic run mojo bench_round.mojo && cd .. && magic run delete_package" +bench_quantize = "magic run package && cd benches && magic run mojo bench_quantize.mojo && cd .. && magic run delete_package" bench_from_float = "magic run package && cd benches && magic run mojo bench_from_float.mojo && cd .. && magic run delete_package" bench_from_string = "magic run package && cd benches && magic run mojo bench_from_string.mojo && cd .. && magic run delete_package" bench_from_int = "magic run package && cd benches && magic run mojo bench_from_int.mojo && cd .. && magic run delete_package" @@ -63,10 +68,3 @@ bench_exp = "magic run package && cd benches && magic run mojo bench_exp.mojo && bench_ln = "magic run package && cd benches && magic run mojo bench_ln.mojo && cd .. && magic run delete_package" bench_log10 = "magic run package && cd benches && magic run mojo bench_log10.mojo && cd .. && magic run delete_package" bench_power = "magic run package && cd benches && magic run mojo bench_power.mojo && cd .. && magic run delete_package" - -# before commit -final = "magic run test && magic run bench" -f = "clear && magic run final" - -[dependencies] -max = ">=25.1,<25.3" \ No newline at end of file diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 3318eb6..3ed4b28 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -51,7 +51,7 @@ from .comparison import ( from .exponential import power, root, sqrt, exp, ln, log, log10 -from .rounding import round +from .rounding import round, quantize from .special import ( factorial, diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index 0a406a1..fcd32c7 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -1368,23 +1368,23 @@ struct Decimal( Compared to `__round__`, this method: (1) Allows specifying the rounding mode. (2) Raises an error if the operation would result in overflow. - - Args: - ndigits: The number of decimal places to round to. - Default is 0. - rounding_mode: The rounding mode to use. - Default is RoundingMode.ROUND_HALF_EVEN. - - Returns: - The rounded Decimal value. - - Raises: - Error: If calling `round()` failed. + See `round()` for more information. """ return decimojo.rounding.round( self, ndigits=ndigits, rounding_mode=rounding_mode ) + @always_inline + fn quantize( + self, + exp: Decimal, + rounding_mode: RoundingMode = RoundingMode.ROUND_HALF_EVEN, + ) raises -> Self: + """Quantizes this Decimal to the specified exponent. + See `quantize()` for more information. + """ + return decimojo.rounding.quantize(self, exp, rounding_mode) + @always_inline fn exp(self) raises -> Self: """Calculates the exponential of this Decimal. diff --git a/src/decimojo/rounding.mojo b/src/decimojo/rounding.mojo index abe8a03..e7ae548 100644 --- a/src/decimojo/rounding.mojo +++ b/src/decimojo/rounding.mojo @@ -165,3 +165,59 @@ fn round( # Add a fallback raise even if it seems unreachable testing.assert_true(False, "Unreachable code path reached") return number + + +fn quantize( + value: Decimal, + exp: Decimal, + rounding_mode: RoundingMode = RoundingMode.ROUND_HALF_EVEN, +) raises -> Decimal: + """Rounds the value according to the exponent of the second operand. + Unlike `round()`, the scale is determined by the scale of the second + operand, not a number of digits. `quantize()` returns the same value as + `round()` when the scale of the second operand is non-negative. + + Args: + value: The Decimal value to quantize. + exp: A Decimal whose scale (exponent) will be used for the result. + rounding_mode: The rounding mode to use. + Defaults to ROUND_HALF_EVEN (banker's rounding). + + Returns: + A new Decimal with the same value as the first operand (except for + rounding) and the same scale (exponent) as the second operand. + + Raises: + Error: If the resulting number doesn't fit within the valid range. + + Examples: + + ```mojo + from decimojo import Decimal + _ = Decimal("1.2345").quantize(Decimal("0.001")) # -> Decimal("1.234") + _ = Decimal("1.2345").quantize(Decimal("0.01")) # -> Decimal("1.23") + _ = Decimal("1.2345").quantize(Decimal("0.1")) # -> Decimal("1.2") + _ = Decimal("1.2345").quantize(Decimal("1")) # -> Decimal("1") + _ = Decimal("1.2345").quantize(Decimal("10")) # -> Decimal("1") + # Compare with round() + _ = Decimal("1.2345").round(-1) # -> Decimal("0") + ``` + End of examples. + """ + + # Determine the scale of the target exponent + var target_scale = exp.scale() + # Determine the scale of the value + var value_scale = value.scale() + + # If the scales are already the same, no quantization needed + if target_scale == value_scale: + return value + + # If the target scale is non-negative, round the value to the target scale + elif target_scale >= 0: + return round(value, target_scale, rounding_mode) + + # If the target scale is negative, round the value to integer + else: + return round(value, 0, rounding_mode) diff --git a/tests/test_quantize.mojo b/tests/test_quantize.mojo new file mode 100644 index 0000000..ce94fa9 --- /dev/null +++ b/tests/test_quantize.mojo @@ -0,0 +1,647 @@ +""" +Comprehensive tests for the Decimal.quantize() method. +Tests various scenarios to ensure proper quantization behavior and compatibility +with Python's decimal module implementation. +""" + +import testing +from python import Python, PythonObject +from decimojo.prelude import dm, Decimal, RoundingMode + + +fn test_basic_quantization() raises: + """Test basic quantization with different scales.""" + print("Testing basic quantization...") + + var pydecimal = Python.import_module("decimal") + pydecimal.getcontext().prec = 28 # Match DeciMojo's precision + + var value1 = Decimal("3.14159") + var quant1 = Decimal("0.01") + var result1 = value1.quantize(quant1) + var py_value1 = pydecimal.Decimal("3.14159") + var py_quant1 = pydecimal.Decimal("0.01") + var py_result1 = py_value1.quantize(py_quant1) + + testing.assert_equal( + String(result1), + String(py_result1), + "Quantizing 3.14159 to 0.01 gave incorrect result: " + String(result1), + ) + + var value2 = Decimal("42.7") + var quant2 = Decimal("1") + var result2 = value2.quantize(quant2) + var py_value2 = pydecimal.Decimal("42.7") + var py_quant2 = pydecimal.Decimal("1") + var py_result2 = py_value2.quantize(py_quant2) + + testing.assert_equal( + String(result2), + String(py_result2), + "Quantizing 42.7 to 1 gave incorrect result: " + String(result2), + ) + + var value3 = Decimal("5.5") + var quant3 = Decimal("0.001") + var result3 = value3.quantize(quant3) + var py_value3 = pydecimal.Decimal("5.5") + var py_quant3 = pydecimal.Decimal("0.001") + var py_result3 = py_value3.quantize(py_quant3) + + testing.assert_equal( + String(result3), + String(py_result3), + "Quantizing 5.5 to 0.001 gave incorrect result: " + String(result3), + ) + + var value4 = Decimal("123.456789") + var quant4 = Decimal("0.01") + var result4 = value4.quantize(quant4) + var py_value4 = pydecimal.Decimal("123.456789") + var py_quant4 = pydecimal.Decimal("0.01") + var py_result4 = py_value4.quantize(py_quant4) + + testing.assert_equal( + String(result4), + String(py_result4), + "Quantizing 123.456789 to 0.01 gave incorrect result: " + + String(result4), + ) + + var value5 = Decimal("9.876") + var quant5 = Decimal("1.00") + var result5 = value5.quantize(quant5) + var py_value5 = pydecimal.Decimal("9.876") + var py_quant5 = pydecimal.Decimal("1.00") + var py_result5 = py_value5.quantize(py_quant5) + + testing.assert_equal( + String(result5), + String(py_result5), + "Quantizing 9.876 to 1.00 gave incorrect result: " + String(result5), + ) + + print("✓ Basic quantization tests passed!") + + +fn test_rounding_modes() raises: + """Test quantization with different rounding modes.""" + print("Testing quantization with different rounding modes...") + + var pydecimal = Python.import_module("decimal") + pydecimal.getcontext().prec = 28 + + var test_value = Decimal("3.5") + var quantizer = Decimal("1") + var py_value = pydecimal.Decimal("3.5") + var py_quantizer = pydecimal.Decimal("1") + + var result1 = test_value.quantize(quantizer, RoundingMode.ROUND_HALF_EVEN) + var py_result1 = py_value.quantize( + py_quantizer, rounding=pydecimal.ROUND_HALF_EVEN + ) + testing.assert_equal( + String(result1), + String(py_result1), + "ROUND_HALF_EVEN gave incorrect result: " + String(result1), + ) + + var result2 = test_value.quantize(quantizer, RoundingMode.ROUND_HALF_UP) + var py_result2 = py_value.quantize( + py_quantizer, rounding=pydecimal.ROUND_HALF_UP + ) + testing.assert_equal( + String(result2), + String(py_result2), + "ROUND_HALF_UP gave incorrect result: " + String(result2), + ) + + var result3 = test_value.quantize(quantizer, RoundingMode.ROUND_DOWN) + var py_result3 = py_value.quantize( + py_quantizer, rounding=pydecimal.ROUND_DOWN + ) + testing.assert_equal( + String(result3), + String(py_result3), + "ROUND_DOWN gave incorrect result: " + String(result3), + ) + + var result4 = test_value.quantize(quantizer, RoundingMode.ROUND_UP) + var py_result4 = py_value.quantize( + py_quantizer, rounding=pydecimal.ROUND_UP + ) + testing.assert_equal( + String(result4), + String(py_result4), + "ROUND_UP gave incorrect result: " + String(result4), + ) + + var neg_test_value = Decimal("-3.5") + var result5 = neg_test_value.quantize(quantizer, RoundingMode.ROUND_DOWN) + var py_neg_value = pydecimal.Decimal("-3.5") + var py_result5 = py_neg_value.quantize( + py_quantizer, rounding=pydecimal.ROUND_DOWN + ) + testing.assert_equal( + String(result5), + String(py_result5), + "ROUND_DOWN with negative gave incorrect result: " + String(result5), + ) + + var result6 = neg_test_value.quantize(quantizer, RoundingMode.ROUND_UP) + var py_result6 = py_neg_value.quantize( + py_quantizer, rounding=pydecimal.ROUND_UP + ) + testing.assert_equal( + String(result6), + String(py_result6), + "ROUND_UP with negative gave incorrect result: " + String(result6), + ) + + print("✓ Rounding mode tests passed!") + + +fn test_edge_cases() raises: + """Test edge cases for quantization.""" + print("Testing quantization edge cases...") + + var pydecimal = Python.import_module("decimal") + pydecimal.getcontext().prec = 28 + + var zero = Decimal("0") + var quant1 = Decimal("0.001") + var result1 = zero.quantize(quant1) + var py_zero = pydecimal.Decimal("0") + var py_quant1 = pydecimal.Decimal("0.001") + var py_result1 = py_zero.quantize(py_quant1) + + testing.assert_equal( + String(result1), + String(py_result1), + "Quantizing 0 to 0.001 gave incorrect result: " + String(result1), + ) + + var value2 = Decimal("123.45") + var quant2 = Decimal("0.01") + var result2 = value2.quantize(quant2) + var py_value2 = pydecimal.Decimal("123.45") + var py_quant2 = pydecimal.Decimal("0.01") + var py_result2 = py_value2.quantize(py_quant2) + + testing.assert_equal( + String(result2), + String(py_result2), + "Quantizing to same exponent gave incorrect result: " + String(result2), + ) + + var value3 = Decimal("9.9999") + var quant3 = Decimal("1") + var result3 = value3.quantize(quant3) + var py_value3 = pydecimal.Decimal("9.9999") + var py_quant3 = pydecimal.Decimal("1") + var py_result3 = py_value3.quantize(py_quant3) + + testing.assert_equal( + String(result3), + String(py_result3), + "Rounding 9.9999 to 1 gave incorrect result: " + String(result3), + ) + + var value4 = Decimal("0.0000001") + var quant4 = Decimal("0.001") + var result4 = value4.quantize(quant4) + var py_value4 = pydecimal.Decimal("0.0000001") + var py_quant4 = pydecimal.Decimal("0.001") + var py_result4 = py_value4.quantize(py_quant4) + + testing.assert_equal( + String(result4), + String(py_result4), + "Quantizing very small number gave incorrect result: " + + String(result4), + ) + + var value5 = Decimal("-1.5") + var quant5 = Decimal("1") + var result5 = value5.quantize(quant5, RoundingMode.ROUND_HALF_EVEN) + var py_value5 = pydecimal.Decimal("-1.5") + var py_quant5 = pydecimal.Decimal("1") + var py_result5 = py_value5.quantize( + py_quant5, rounding=pydecimal.ROUND_HALF_EVEN + ) + + testing.assert_equal( + String(result5), + String(py_result5), + "Quantizing -1.5 with ROUND_HALF_EVEN gave incorrect result: " + + String(result5), + ) + + print("✓ Edge cases tests passed!") + + +fn test_special_cases() raises: + """Test special cases for quantization.""" + print("Testing special quantization cases...") + + var pydecimal = Python.import_module("decimal") + pydecimal.getcontext().prec = 28 + + var value1 = Decimal("12.34") + var quant1 = Decimal("0.0000") + var result1 = value1.quantize(quant1) + var py_value1 = pydecimal.Decimal("12.34") + var py_quant1 = pydecimal.Decimal("0.0000") + var py_result1 = py_value1.quantize(py_quant1) + + testing.assert_equal( + String(result1), + String(py_result1), + "Increasing precision gave incorrect result: " + String(result1), + ) + + var value2 = Decimal("2.5") + var quant2 = Decimal("1") + var result2 = value2.quantize(quant2, RoundingMode.ROUND_HALF_EVEN) + var py_value2 = pydecimal.Decimal("2.5") + var py_quant2 = pydecimal.Decimal("1") + var py_result2 = py_value2.quantize( + py_quant2, rounding=pydecimal.ROUND_HALF_EVEN + ) + + testing.assert_equal( + String(result2), + String(py_result2), + "Banker's rounding for 2.5 gave incorrect result: " + String(result2), + ) + + var value3 = Decimal("123.456") + var quant3 = Decimal("10") + var result3 = value3.quantize(quant3) + var py_value3 = pydecimal.Decimal("123.456") + var py_quant3 = pydecimal.Decimal("10") + var py_result3 = py_value3.quantize(py_quant3) + + testing.assert_equal( + String(result3), + String(py_result3), + "Quantizing with negative exponent gave incorrect result: " + + String(result3), + ) + + var value4 = Decimal("3.1415926535") + var quant4 = Decimal("0.00000001") + var result4 = value4.quantize(quant4) + var py_value4 = pydecimal.Decimal("3.1415926535") + var py_quant4 = pydecimal.Decimal("0.00000001") + var py_result4 = py_value4.quantize(py_quant4) + + testing.assert_equal( + String(result4), + String(py_result4), + "Very precise quantization gave incorrect result: " + String(result4), + ) + + var value5 = Decimal("123.456") + var quant5 = Decimal("1") + var result5 = value5.quantize(quant5) + var py_value5 = pydecimal.Decimal("123.456") + var py_quant5 = pydecimal.Decimal("1") + var py_result5 = py_value5.quantize(py_quant5) + + testing.assert_equal( + String(result5), + String(py_result5), + "Quantizing to integer gave incorrect result: " + String(result5), + ) + + print("✓ Special cases tests passed!") + + +fn test_quantize_exceptions() raises: + """Test exception conditions for quantize().""" + print("Testing quantize exceptions...") + + var pydecimal = Python.import_module("decimal") + pydecimal.getcontext().prec = 28 + + var exception_caught = False + try: + var value1 = Decimal("123.456") + var quant1 = Decimal("1000") + var _result1 = value1.quantize(quant1) + except: + exception_caught = True + + var py_exception_caught = False + try: + var py_value1 = pydecimal.Decimal("123.456") + var py_quant1 = pydecimal.Decimal("1000") + var _py_result1 = py_value1.quantize(py_quant1) + except: + py_exception_caught = True + + testing.assert_equal( + exception_caught, + py_exception_caught, + ( + "Exception handling for invalid quantization doesn't match Python's" + " behavior" + ), + ) + + print("✓ Exception tests passed!") + + +fn test_comprehensive_comparison() raises: + """Test a wide range of values to ensure compatibility with Python's decimal. + """ + print("Testing comprehensive comparison with Python's decimal...") + + # Set up Python decimal + var pydecimal = Python.import_module("decimal") + pydecimal.getcontext().prec = 28 # Match DeciMojo's precision + + # Define rounding modes to test + var mojo_round_half_even = RoundingMode.ROUND_HALF_EVEN + var mojo_round_half_up = RoundingMode.ROUND_HALF_UP + var mojo_round_down = RoundingMode.ROUND_DOWN + var mojo_round_up = RoundingMode.ROUND_UP + + var py_round_half_even = pydecimal.ROUND_HALF_EVEN + var py_round_half_up = pydecimal.ROUND_HALF_UP + var py_round_down = pydecimal.ROUND_DOWN + var py_round_up = pydecimal.ROUND_UP + + # Instead of looping through lists, test each case explicitly + # Test case 1: Zero with integer quantizer + test_single_quantize_case( + "0", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "0", "1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "0", "1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case("0", "1", mojo_round_up, py_round_up, pydecimal) + + # Test case 2: Decimal with 2 decimal places quantizer + test_single_quantize_case( + "1.23456", "0.01", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "1.23456", "0.01", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "1.23456", "0.01", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "1.23456", "0.01", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 3: Decimal with 1 decimal place quantizer + test_single_quantize_case( + "9.999", "0.1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "9.999", "0.1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "9.999", "0.1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "9.999", "0.1", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 4: Negative value with integer quantizer + test_single_quantize_case( + "-0.5", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "-0.5", "1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "-0.5", "1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "-0.5", "1", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 5: Small value with larger precision + test_single_quantize_case( + "0.0001", "0.01", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "0.0001", "0.01", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "0.0001", "0.01", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "0.0001", "0.01", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 6: Large value with integer quantizer + test_single_quantize_case( + "1234.5678", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "1234.5678", "1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "1234.5678", "1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "1234.5678", "1", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 7: Rounding to larger precision + test_single_quantize_case( + "99.99", "100", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "99.99", "100", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "99.99", "100", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "99.99", "100", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 8: Very small value with small precision + test_single_quantize_case( + "0.0000001", + "0.00001", + mojo_round_half_even, + py_round_half_even, + pydecimal, + ) + test_single_quantize_case( + "0.0000001", "0.00001", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "0.0000001", "0.00001", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "0.0000001", "0.00001", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 9: Large value with 1 decimal place + test_single_quantize_case( + "987654.321", "0.1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "987654.321", "0.1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "987654.321", "0.1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "987654.321", "0.1", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 10: Testing banker's rounding + test_single_quantize_case( + "1.5", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "2.5", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + + # Test case 11: Testing rounding to thousands + test_single_quantize_case( + "10000", "1000", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "10000", "1000", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "10000", "1000", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "10000", "1000", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 12: Rounding up very close value + test_single_quantize_case( + "0.999999", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "0.999999", "1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "0.999999", "1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "0.999999", "1", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 13: Pi with very high precision + test_single_quantize_case( + "3.14159265358979323", + "0.00000000001", + mojo_round_half_even, + py_round_half_even, + pydecimal, + ) + + # Test case 14: Negative value rounding + test_single_quantize_case( + "-999.9", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + test_single_quantize_case( + "-999.9", "1", mojo_round_half_up, py_round_half_up, pydecimal + ) + test_single_quantize_case( + "-999.9", "1", mojo_round_down, py_round_down, pydecimal + ) + test_single_quantize_case( + "-999.9", "1", mojo_round_up, py_round_up, pydecimal + ) + + # Test case 15: Zero with trailing zeros + test_single_quantize_case( + "0.0", "0.0000", mojo_round_half_even, py_round_half_even, pydecimal + ) + + # Test case 16: Integer to integer + test_single_quantize_case( + "123", "1", mojo_round_half_even, py_round_half_even, pydecimal + ) + + print("✓ Comprehensive comparison tests passed!") + + +fn test_single_quantize_case( + value_str: String, + quant_str: String, + mojo_mode: RoundingMode, + py_mode: PythonObject, + pydecimal: PythonObject, +) raises: + """Test a single quantize case comparing Mojo and Python implementations.""" + + try: + var mojo_value = Decimal(value_str) + var mojo_quant = Decimal(quant_str) + var py_value = pydecimal.Decimal(value_str) + var py_quant = pydecimal.Decimal(quant_str) + + var mojo_result = mojo_value.quantize(mojo_quant, mojo_mode) + var py_result = py_value.quantize(py_quant, rounding=py_mode) + + testing.assert_equal( + String(mojo_result), + String(py_result), + String("Quantizing {} to {} gave incorrect result: {}").format( + value_str, quant_str, String(mojo_result) + ), + ) + except e: + print( + String("Exception occurred (expected): {} to {}").format( + value_str, quant_str + ) + ) + # Both implementations should either both succeed or both fail + + +fn run_test_with_error_handling( + test_fn: fn () raises -> None, test_name: String +) raises: + """Helper function to run a test function with error handling and reporting. + """ + try: + print("\n" + "=" * 50) + print("RUNNING: " + test_name) + print("=" * 50) + test_fn() + print("\n✓ " + test_name + " passed\n") + except e: + print("\n✗ " + test_name + " FAILED!") + print("Error message: " + String(e)) + raise e + + +fn main() raises: + print("=========================================") + print("Running Decimal.quantize() Tests") + print("=========================================") + + run_test_with_error_handling( + test_basic_quantization, "Basic quantization test" + ) + run_test_with_error_handling(test_rounding_modes, "Rounding modes test") + run_test_with_error_handling(test_edge_cases, "Edge cases test") + run_test_with_error_handling(test_special_cases, "Special cases test") + run_test_with_error_handling( + test_quantize_exceptions, "Exception handling test" + ) + run_test_with_error_handling( + test_comprehensive_comparison, "Comprehensive comparison test" + ) + + print("All Decimal.quantize() tests passed!")