diff --git a/.gitignore b/.gitignore index 2c86457..fa5c872 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ magic.lock # VSCode environments .vscode tempCodeRunnerFile.mojo +/temp*.mojo # macOS environments .DS_Store # log files diff --git a/benches/bench.mojo b/benches/bench.mojo index e231372..5e020a8 100644 --- a/benches/bench.mojo +++ b/benches/bench.mojo @@ -4,6 +4,8 @@ from bench_multiply import main as bench_multiply from bench_divide import main as bench_divide from bench_sqrt import main as bench_sqrt from bench_from_float import main as bench_from_float +from bench_from_string import main as bench_from_string +from bench_comparison import main as bench_comparison fn main() raises: @@ -13,3 +15,5 @@ fn main() raises: bench_divide() bench_sqrt() bench_from_float() + bench_from_string() + bench_comparison() diff --git a/benches/bench_comparison.mojo b/benches/bench_comparison.mojo new file mode 100644 index 0000000..427c5a9 --- /dev/null +++ b/benches/bench_comparison.mojo @@ -0,0 +1,741 @@ +""" +Comprehensive benchmarks for Decimal logical comparison operations. +Compares performance against Python's decimal module across diverse test cases. +Tests all comparison operators: >, >=, ==, <=, <, != +""" + +from decimojo.prelude import dm, Decimal, RoundingMode +from python import Python, PythonObject +from time import perf_counter_ns +import time +import os +from collections import List + + +fn open_log_file() raises -> PythonObject: + """ + Creates and opens a log file with a timestamp in the filename. + + Returns: + A file object opened for writing. + """ + var python = Python.import_module("builtins") + var datetime = Python.import_module("datetime") + + # Create logs directory if it doesn't exist + var log_dir = "./logs" + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Generate a timestamp for the filename + var timestamp = String(datetime.datetime.now().isoformat()) + var log_filename = log_dir + "/benchmark_comparison_" + timestamp + ".log" + + print("Saving benchmark results to:", log_filename) + return python.open(log_filename, "w") + + +fn log_print(msg: String, log_file: PythonObject) raises: + """ + Prints a message to both the console and the log file. + + Args: + msg: The message to print. + log_file: The file object to write to. + """ + print(msg) + log_file.write(msg + "\n") + log_file.flush() # Ensure the message is written immediately + + +fn run_comparison_benchmark( + name: String, + a_mojo: Decimal, + b_mojo: Decimal, + a_py: PythonObject, + b_py: PythonObject, + op: String, + iterations: Int, + log_file: PythonObject, + mut speedup_factors: List[Float64], +) raises: + """ + Run a benchmark comparing Mojo Decimal comparison with Python Decimal comparison. + + Args: + name: Name of the benchmark case. + a_mojo: First Mojo Decimal operand. + b_mojo: Second Mojo Decimal operand. + a_py: First Python Decimal operand. + b_py: Second Python Decimal operand. + op: Comparison operator as string (">", ">=", "==", "<=", "<", "!="). + iterations: Number of iterations for the benchmark. + log_file: File object for logging results. + speedup_factors: Mojo List to store speedup factors for averaging. + """ + log_print("\nBenchmark: " + name, log_file) + log_print("Operator: " + op, log_file) + log_print( + "Decimals: " + String(a_mojo) + " " + op + " " + String(b_mojo), + log_file, + ) + + # Execute the operations once to verify correctness + var mojo_result: Bool + var py_result: PythonObject + + if op == ">": + mojo_result = a_mojo > b_mojo + py_result = a_py > b_py + elif op == ">=": + mojo_result = a_mojo >= b_mojo + py_result = a_py >= b_py + elif op == "==": + mojo_result = a_mojo == b_mojo + py_result = a_py == b_py + elif op == "<=": + mojo_result = a_mojo <= b_mojo + py_result = a_py <= b_py + elif op == "<": + mojo_result = a_mojo < b_mojo + py_result = a_py < b_py + elif op == "!=": + mojo_result = a_mojo != b_mojo + py_result = a_py != b_py + else: + log_print("Error: Invalid operator '" + op + "'", log_file) + return + + # Display results for verification + log_print("Mojo result: " + String(mojo_result), log_file) + log_print("Python result: " + String(py_result), log_file) + + # Benchmark Mojo implementation + var t0 = perf_counter_ns() + for _ in range(iterations): + if op == ">": + _ = a_mojo > b_mojo + elif op == ">=": + _ = a_mojo >= b_mojo + elif op == "==": + _ = a_mojo == b_mojo + elif op == "<=": + _ = a_mojo <= b_mojo + elif op == "<": + _ = a_mojo < b_mojo + elif op == "!=": + _ = a_mojo != b_mojo + var mojo_time = (perf_counter_ns() - t0) / iterations + if mojo_time == 0: + mojo_time = 1 # Avoid division by zero + + # Benchmark Python implementation + t0 = perf_counter_ns() + for _ in range(iterations): + if op == ">": + _ = a_py > b_py + elif op == ">=": + _ = a_py >= b_py + elif op == "==": + _ = a_py == b_py + elif op == "<=": + _ = a_py <= b_py + elif op == "<": + _ = a_py < b_py + elif op == "!=": + _ = a_py != b_py + var python_time = (perf_counter_ns() - t0) / iterations + + # Calculate speedup factor + var speedup = python_time / mojo_time + speedup_factors.append(Float64(speedup)) + + # Print results with speedup comparison + log_print( + "Mojo Decimal: " + String(mojo_time) + " ns per operation", + log_file, + ) + log_print( + "Python Decimal: " + String(python_time) + " ns per operation", + log_file, + ) + log_print("Speedup factor: " + String(speedup), log_file) + + +fn main() raises: + # Open log file + var log_file = open_log_file() + var datetime = Python.import_module("datetime") + + # Create a Mojo List to store speedup factors for averaging later + var speedup_factors = List[Float64]() + + # Display benchmark header with system information + log_print("=== DeciMojo Logical Comparison Benchmark ===", log_file) + log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file) + + # Try to get system info + try: + var platform = Python.import_module("platform") + log_print( + "System: " + + String(platform.system()) + + " " + + String(platform.release()), + log_file, + ) + log_print("Processor: " + String(platform.processor()), log_file) + log_print( + "Python version: " + String(platform.python_version()), log_file + ) + except: + log_print("Could not retrieve system information", log_file) + + var iterations = 10000 + var pydecimal = Python().import_module("decimal") + + # Set Python decimal precision to match Mojo's + pydecimal.getcontext().prec = 28 + log_print( + "Python decimal precision: " + String(pydecimal.getcontext().prec), + log_file, + ) + log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file) + + # Define benchmark cases + log_print( + "\nRunning logical comparison benchmarks with " + + String(iterations) + + " iterations each", + log_file, + ) + + # Test Case 1: Equal integers + var case1_a_mojo = Decimal("100") + var case1_b_mojo = Decimal("100") + var case1_a_py = pydecimal.Decimal("100") + var case1_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Equal integers", + case1_a_mojo, + case1_b_mojo, + case1_a_py, + case1_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 2: Different integers + var case2_a_mojo = Decimal("100") + var case2_b_mojo = Decimal("200") + var case2_a_py = pydecimal.Decimal("100") + var case2_b_py = pydecimal.Decimal("200") + run_comparison_benchmark( + "Different integers (<)", + case2_a_mojo, + case2_b_mojo, + case2_a_py, + case2_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 3: Different integers (>) + run_comparison_benchmark( + "Different integers (>)", + case2_b_mojo, + case2_a_mojo, + case2_b_py, + case2_a_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 4: Equal decimals with different representations + var case4_a_mojo = Decimal("100.00") + var case4_b_mojo = Decimal("100") + var case4_a_py = pydecimal.Decimal("100.00") + var case4_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Equal decimal with different scales", + case4_a_mojo, + case4_b_mojo, + case4_a_py, + case4_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 5: Compare with zero + var case5_a_mojo = Decimal("0") + var case5_b_mojo = Decimal("-0.00") + var case5_a_py = pydecimal.Decimal("0") + var case5_b_py = pydecimal.Decimal("-0.00") + run_comparison_benchmark( + "Zero comparison (==)", + case5_a_mojo, + case5_b_mojo, + case5_a_py, + case5_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 6: Very small difference + var case6_a_mojo = Decimal("0.0000000000000000000000000001") + var case6_b_mojo = Decimal("0.0000000000000000000000000002") + var case6_a_py = pydecimal.Decimal("0.0000000000000000000000000001") + var case6_b_py = pydecimal.Decimal("0.0000000000000000000000000002") + run_comparison_benchmark( + "Very small difference (<)", + case6_a_mojo, + case6_b_mojo, + case6_a_py, + case6_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 7: Very large numbers + var case7_a_mojo = Decimal("9999999999999999999999999999") + var case7_b_mojo = Decimal("9999999999999999999999999998") + var case7_a_py = pydecimal.Decimal("9999999999999999999999999999") + var case7_b_py = pydecimal.Decimal("9999999999999999999999999998") + run_comparison_benchmark( + "Very large numbers (>)", + case7_a_mojo, + case7_b_mojo, + case7_a_py, + case7_b_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 8: Negative numbers + var case8_a_mojo = Decimal("-10") + var case8_b_mojo = Decimal("-20") + var case8_a_py = pydecimal.Decimal("-10") + var case8_b_py = pydecimal.Decimal("-20") + run_comparison_benchmark( + "Negative numbers (>)", + case8_a_mojo, + case8_b_mojo, + case8_a_py, + case8_b_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 9: Mixed sign comparison + var case9_a_mojo = Decimal("-10") + var case9_b_mojo = Decimal("10") + var case9_a_py = pydecimal.Decimal("-10") + var case9_b_py = pydecimal.Decimal("10") + run_comparison_benchmark( + "Mixed signs (<)", + case9_a_mojo, + case9_b_mojo, + case9_a_py, + case9_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 10: Not equal comparison + var case10_a_mojo = Decimal("99.99") + var case10_b_mojo = Decimal("100") + var case10_a_py = pydecimal.Decimal("99.99") + var case10_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Not equal (!=)", + case10_a_mojo, + case10_b_mojo, + case10_a_py, + case10_b_py, + "!=", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 11: Less than or equal (true because less) + var case11_a_mojo = Decimal("50") + var case11_b_mojo = Decimal("100") + var case11_a_py = pydecimal.Decimal("50") + var case11_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Less than or equal (<=, true because less)", + case11_a_mojo, + case11_b_mojo, + case11_a_py, + case11_b_py, + "<=", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 12: Less than or equal (true because equal) + var case12_a_mojo = Decimal("100") + var case12_b_mojo = Decimal("100") + var case12_a_py = pydecimal.Decimal("100") + var case12_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Less than or equal (<=, true because equal)", + case12_a_mojo, + case12_b_mojo, + case12_a_py, + case12_b_py, + "<=", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 13: Greater than or equal (true because greater) + var case13_a_mojo = Decimal("200") + var case13_b_mojo = Decimal("100") + var case13_a_py = pydecimal.Decimal("200") + var case13_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Greater than or equal (>=, true because greater)", + case13_a_mojo, + case13_b_mojo, + case13_a_py, + case13_b_py, + ">=", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 14: Greater than or equal (true because equal) + var case14_a_mojo = Decimal("100") + var case14_b_mojo = Decimal("100") + var case14_a_py = pydecimal.Decimal("100") + var case14_b_py = pydecimal.Decimal("100") + run_comparison_benchmark( + "Greater than or equal (>=, true because equal)", + case14_a_mojo, + case14_b_mojo, + case14_a_py, + case14_b_py, + ">=", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 15: Equal with high precision after decimal + var case15_a_mojo = Decimal("0.12345678901234567890123456789") + var case15_b_mojo = Decimal("0.12345678901234567890123456789") + var case15_a_py = pydecimal.Decimal("0.12345678901234567890123456789") + var case15_b_py = pydecimal.Decimal("0.12345678901234567890123456789") + run_comparison_benchmark( + "Equal high precision numbers", + case15_a_mojo, + case15_b_mojo, + case15_a_py, + case15_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 16: Almost equal high precision + var case16_a_mojo = Decimal("0.12345678901234567890123456780") + var case16_b_mojo = Decimal("0.12345678901234567890123456789") + var case16_a_py = pydecimal.Decimal("0.12345678901234567890123456780") + var case16_b_py = pydecimal.Decimal("0.12345678901234567890123456789") + run_comparison_benchmark( + "Almost equal high precision (<)", + case16_a_mojo, + case16_b_mojo, + case16_a_py, + case16_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 17: Equal but different trailing zeros + var case17_a_mojo = Decimal("1.10000") + var case17_b_mojo = Decimal("1.1") + var case17_a_py = pydecimal.Decimal("1.10000") + var case17_b_py = pydecimal.Decimal("1.1") + run_comparison_benchmark( + "Equal with different trailing zeros", + case17_a_mojo, + case17_b_mojo, + case17_a_py, + case17_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 18: Not equal with trailing zeros that matter + var case18_a_mojo = Decimal("1.10001") + var case18_b_mojo = Decimal("1.1") + var case18_a_py = pydecimal.Decimal("1.10001") + var case18_b_py = pydecimal.Decimal("1.1") + run_comparison_benchmark( + "Not equal with significant trailing digits", + case18_a_mojo, + case18_b_mojo, + case18_a_py, + case18_b_py, + "!=", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 19: Large positive vs small negative + var case19_a_mojo = Decimal("9999999") + var case19_b_mojo = Decimal("-0.000001") + var case19_a_py = pydecimal.Decimal("9999999") + var case19_b_py = pydecimal.Decimal("-0.000001") + run_comparison_benchmark( + "Large positive vs small negative (>)", + case19_a_mojo, + case19_b_mojo, + case19_a_py, + case19_b_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 20: Equal near zero + var case20_a_mojo = Decimal("0.000000000000000000000000001") + var case20_b_mojo = Decimal("0.000000000000000000000000001") + var case20_a_py = pydecimal.Decimal("0.000000000000000000000000001") + var case20_b_py = pydecimal.Decimal("0.000000000000000000000000001") + run_comparison_benchmark( + "Equal near zero (==)", + case20_a_mojo, + case20_b_mojo, + case20_a_py, + case20_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 21: Common financial values + var case21_a_mojo = Decimal("19.99") + var case21_b_mojo = Decimal("20.00") + var case21_a_py = pydecimal.Decimal("19.99") + var case21_b_py = pydecimal.Decimal("20.00") + run_comparison_benchmark( + "Common financial values (<)", + case21_a_mojo, + case21_b_mojo, + case21_a_py, + case21_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 22: Different sign zeros + var case22_a_mojo = Decimal("0") + var case22_b_mojo = Decimal("-0") + var case22_a_py = pydecimal.Decimal("0") + var case22_b_py = pydecimal.Decimal("-0") + run_comparison_benchmark( + "Different sign zeros (==)", + case22_a_mojo, + case22_b_mojo, + case22_a_py, + case22_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 23: Repeated digits comparison + var case23_a_mojo = Decimal("9.999999999") + var case23_b_mojo = Decimal("10") + var case23_a_py = pydecimal.Decimal("9.999999999") + var case23_b_py = pydecimal.Decimal("10") + run_comparison_benchmark( + "Repeated digits comparison (<)", + case23_a_mojo, + case23_b_mojo, + case23_a_py, + case23_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 24: Scientific notation equivalent + var case24_a_mojo = Decimal("1.23e2") + var case24_b_mojo = Decimal("123") + var case24_a_py = pydecimal.Decimal("1.23e2") + var case24_b_py = pydecimal.Decimal("123") + run_comparison_benchmark( + "Scientific notation equivalent (==)", + case24_a_mojo, + case24_b_mojo, + case24_a_py, + case24_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 25: Same value different format + var case25_a_mojo = Decimal("100.00") + var case25_b_mojo = Decimal("1.0e2") + var case25_a_py = pydecimal.Decimal("100.00") + var case25_b_py = pydecimal.Decimal("1.0e2") + run_comparison_benchmark( + "Same value different format (==)", + case25_a_mojo, + case25_b_mojo, + case25_a_py, + case25_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 26: Almost but not quite equal + var case26_a_mojo = Decimal("0.999999999999999999999") + var case26_b_mojo = Decimal("1.000000000000000000000") + var case26_a_py = pydecimal.Decimal("0.999999999999999999999") + var case26_b_py = pydecimal.Decimal("1.000000000000000000000") + run_comparison_benchmark( + "Almost but not quite equal values (<)", + case26_a_mojo, + case26_b_mojo, + case26_a_py, + case26_b_py, + "<", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 27: Greater than comparison (negative numbers) + var case27_a_mojo = Decimal("-100") + var case27_b_mojo = Decimal("-200") + var case27_a_py = pydecimal.Decimal("-100") + var case27_b_py = pydecimal.Decimal("-200") + run_comparison_benchmark( + "Greater than with negative numbers (>)", + case27_a_mojo, + case27_b_mojo, + case27_a_py, + case27_b_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 28: Equal negative values + var case28_a_mojo = Decimal("-42.5") + var case28_b_mojo = Decimal("-42.50") + var case28_a_py = pydecimal.Decimal("-42.5") + var case28_b_py = pydecimal.Decimal("-42.50") + run_comparison_benchmark( + "Equal negative values (==)", + case28_a_mojo, + case28_b_mojo, + case28_a_py, + case28_b_py, + "==", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 29: Close to zero comparison + var case29_a_mojo = Decimal("0.0000000000000000000000000001") + var case29_b_mojo = Decimal("0") + var case29_a_py = pydecimal.Decimal("0.0000000000000000000000000001") + var case29_b_py = pydecimal.Decimal("0") + run_comparison_benchmark( + "Close to zero comparison (>)", + case29_a_mojo, + case29_b_mojo, + case29_a_py, + case29_b_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Test Case 30: Boundary values + var case30_a_mojo = Decimal.MAX() + var case30_b_mojo = Decimal.MAX() - 1 + var case30_a_py = pydecimal.Decimal(String(Decimal.MAX())) + var case30_b_py = pydecimal.Decimal(String(Decimal.MAX() - 1)) + run_comparison_benchmark( + "Boundary values (>)", + case30_a_mojo, + case30_b_mojo, + case30_a_py, + case30_b_py, + ">", + iterations, + log_file, + speedup_factors, + ) + + # Calculate and report average speedup + var total_speedup = 0.0 + for i in range(speedup_factors.__len__()): + total_speedup += speedup_factors[i] + var avg_speedup = total_speedup / Float64(speedup_factors.__len__()) + + log_print("\n===== Summary =====", log_file) + log_print( + "Total test cases: " + String(speedup_factors.__len__()), log_file + ) + log_print("Average speedup factor: " + String(avg_speedup), log_file) + + # List all speedup factors + log_print("\nIndividual speedup factors:", log_file) + for i in range(len(speedup_factors)): + log_print( + String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)), + log_file, + ) + + # Close the log file + log_file.close() + print("Benchmark completed. Log file closed.") diff --git a/mojoproject.toml b/mojoproject.toml index 056dd3b..6ec5209 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -39,6 +39,8 @@ test_round = "magic run package && magic run mojo test tests/test_round.mojo && test_creation = "magic run package && magic run mojo test tests/test_creation.mojo && magic run delete_package" test_from_float = "magic run package && magic run mojo test tests/test_from_float.mojo && magic run delete_package" test_from_string = "magic run package && magic run mojo test tests/test_from_string.mojo && magic run delete_package" +test_comparison = "magic run package && magic run mojo test tests/test_comparison.mojo && magic run delete_package" + # benches bench = "magic run package && cd benches && magic run mojo bench.mojo && cd .. && magic run delete_package" @@ -49,6 +51,7 @@ bench_sqrt = "magic run package && cd benches && magic run mojo bench_sqrt.mojo bench_round = "magic run package && cd benches && magic run mojo bench_round.mojo && cd .. && magic run delete_package" bench_from_float = "magic run package && cd benches && magic run mojo bench_from_float.mojo && cd .. && magic run delete_package" bench_from_string = "magic run package && cd benches && magic run mojo bench_from_string.mojo && cd .. && magic run delete_package" +bench_comparison = "magic run package && cd benches && magic run mojo bench_comparison.mojo && cd .. && magic run delete_package" # before commit final = "magic run test && magic run bench" diff --git a/src/decimojo/__init__.mojo b/src/decimojo/__init__.mojo index 63d943d..3488cfa 100644 --- a/src/decimojo/__init__.mojo +++ b/src/decimojo/__init__.mojo @@ -29,4 +29,11 @@ from .maths import ( absolute, ) -from .logic import greater, greater_equal, less, less_equal, equal, not_equal +from .comparison import ( + greater, + greater_equal, + less, + less_equal, + equal, + not_equal, +) diff --git a/src/decimojo/comparison.mojo b/src/decimojo/comparison.mojo new file mode 100644 index 0000000..3652ac7 --- /dev/null +++ b/src/decimojo/comparison.mojo @@ -0,0 +1,233 @@ +# ===----------------------------------------------------------------------=== # +# Distributed under the Apache 2.0 License with LLVM Exceptions. +# See LICENSE and the LLVM License for more information. +# https://github.com/forFudan/decimojo/blob/main/LICENSE +# ===----------------------------------------------------------------------=== # +# +# Implements comparison operations for the Decimal type +# +# ===----------------------------------------------------------------------=== # +# +# List of functions in this module: +# +# compare(x: Decimal, y: Decimal) -> Int8: Compares two Decimals +# compare_absolute(x: Decimal, y: Decimal) -> Int8: Compares absolute values of two Decimals +# greater(a: Decimal, b: Decimal) -> Bool: Returns True if a > b +# less(a: Decimal, b: Decimal) -> Bool: Returns True if a < b +# greater_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a >= b +# less_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a <= b +# equal(a: Decimal, b: Decimal) -> Bool: Returns True if a == b +# not_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a != b +# +# List of internal functions in this module: +# +# _compare_abs(a: Decimal, b: Decimal) -> Int: Compares absolute values of two Decimals +# +# ===----------------------------------------------------------------------=== # + +""" +Implements functions for comparison operations on Decimal objects. +""" + +import testing + +from decimojo.decimal import Decimal +import decimojo.utility + + +fn compare(x: Decimal, y: Decimal) -> Int8: + """ + Compares the values of two Decimal numbers and returns the result. + + Args: + x: First Decimal value. + y: Second Decimal value. + + Returns: + Terinary value indicating the comparison result: + (1) 1 if x > y. + (2) 0 if x = y. + (3) -1 if x < y. + """ + + # If both are zero, they are equal regardless of scale or sign + if x.is_zero() and y.is_zero(): + return 0 + + # If x is zero, it is less than any non-zero number + elif x.is_zero(): + return 1 if y.is_negative() else -1 + + # If y is zero, it is less than any non-zero number + elif y.is_zero(): + return -1 if x.is_negative() else 1 + + # If signs differ, the positive one is greater + elif x.is_negative() != y.is_negative(): + return -1 if x.is_negative() else 1 + + # If they have the same sign, compare the absolute values + elif x.is_negative(): + return -compare_absolute(x, y) + + else: + return compare_absolute(x, y) + + +fn compare_absolute(x: Decimal, y: Decimal) -> Int8: + """ + Compares the absolute values of two Decimal numbers and returns the result. + + Args: + x: First Decimal value. + y: Second Decimal value. + + Returns: + Terinary value indicating the comparison result: + (1) 1 if |x| > |y|. + (2) 0 if |x| = |y|. + (3) -1 if |x| < |y|. + """ + + var x_coef: UInt128 = x.coefficient() + var y_coef: UInt128 = y.coefficient() + var x_scale: Int = x.scale() + var y_scale: Int = y.scale() + + # CASE: The scales are the same + # Compare the coefficients directly + if x_scale == y_scale and x_coef == y_coef: + return 0 + if x_scale == y_scale: + return (Int8(x_coef > y_coef)) - (Int8(x_coef < y_coef)) + + # CASE: The scales are different + # Compare the integral part first + # If the integral part is the same, compare the fractional part + else: + # Early return if integer parts have different lengths + # Get number of integer digits + var x_int_digits = decimojo.utility.number_of_digits(x_coef) - x_scale + var y_int_digits = decimojo.utility.number_of_digits(y_coef) - y_scale + if x_int_digits > y_int_digits: + return 1 + if x_int_digits < y_int_digits: + return -1 + + # If interger parts have the same length, compare the integer parts + var x_scale_power = UInt128(10) ** (x_scale) + var y_scale_power = UInt128(10) ** (y_scale) + var x_int = x_coef // x_scale_power + var y_int = y_coef // y_scale_power + + if x_int > y_int: + return 1 + elif x_int < y_int: + return -1 + else: + var x_frac = x_coef % x_scale_power + var y_frac = y_coef % y_scale_power + + # Adjust the fractional part to have the same scale + var scale_diff = x_scale - y_scale + if scale_diff > 0: + y_frac *= UInt128(10) ** scale_diff + else: + x_frac *= UInt128(10) ** (-scale_diff) + + if x_frac > y_frac: + return 1 + elif x_frac < y_frac: + return -1 + else: + return 0 + + +fn greater(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a > b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is greater than b, False otherwise. + """ + + return compare(a, b) == 1 + + +fn less(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a < b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is less than b, False otherwise. + """ + + return compare(a, b) == -1 + + +fn greater_equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a >= b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is greater than or equal to b, False otherwise. + """ + + return compare(a, b) >= 0 + + +fn less_equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a <= b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is less than or equal to b, False otherwise. + """ + + return not greater(a, b) + + +fn equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a == b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a equals b, False otherwise. + """ + + return compare(a, b) == 0 + + +fn not_equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a != b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is not equal to b, False otherwise. + """ + + return compare(a, b) != 0 diff --git a/src/decimojo/decimal.mojo b/src/decimojo/decimal.mojo index 1dc25b9..033c72b 100644 --- a/src/decimojo/decimal.mojo +++ b/src/decimojo/decimal.mojo @@ -18,7 +18,7 @@ # - Output dunders, type-transfer dunders, and other type-transfer methods # - Basic unary arithmetic operation dunders # - Basic binary arithmetic operation dunders -# - Basic binary logic operation dunders +# - Basic comparison operation dunders # - Other dunders that implements traits # - Mathematical methods that do not implement a trait (not a dunder) # - Other methods @@ -41,7 +41,7 @@ Implements basic object methods for working with decimal numbers. from memory import UnsafePointer -import decimojo.logic +import decimojo.comparison import decimojo.maths from decimojo.rounding_mode import RoundingMode import decimojo.utility @@ -925,24 +925,18 @@ struct Decimal( Returns: The absolute value of this Decimal. """ - var result = Decimal.from_words( - self.low, self.mid, self.high, self.flags - ) - result.flags &= ~Self.SIGN_MASK # Clear sign bit - return result + return decimojo.maths.absolute(self) fn __neg__(self) -> Self: - """Unary negation operator.""" - # Special case for negative zero - if self.is_zero(): - return Decimal.ZERO() + """ + Returns the negation of this Decimal. - var result = Decimal.from_words( - self.low, self.mid, self.high, self.flags - ) - result.flags ^= Self.SIGN_MASK # Flip sign bit - return result + Returns: + The negation of this Decimal. + """ + + return decimojo.maths.negative(self) # ===------------------------------------------------------------------=== # # Basic binary arithmetic operation dunders @@ -1075,7 +1069,7 @@ struct Decimal( return decimal.power(self, Decimal(exponent)) # ===------------------------------------------------------------------=== # - # Basic binary logic operation dunders + # Basic binary comparison operation dunders # __gt__, __ge__, __lt__, __le__, __eq__, __ne__ # ===------------------------------------------------------------------=== # @@ -1089,31 +1083,31 @@ struct Decimal( Returns: True if self is greater than other, False otherwise. """ - return decimojo.logic.greater(self, other) + return decimojo.comparison.greater(self, other) - fn __ge__(self, other: Decimal) -> Bool: + fn __lt__(self, other: Decimal) -> Bool: """ - Greater than or equal comparison operator. + Less than comparison operator. Args: other: The Decimal to compare with. Returns: - True if self is greater than or equal to other, False otherwise. + True if self is less than other, False otherwise. """ - return decimojo.logic.greater_equal(self, other) + return decimojo.comparison.less(self, other) - fn __lt__(self, other: Decimal) -> Bool: + fn __ge__(self, other: Decimal) -> Bool: """ - Less than comparison operator. + Greater than or equal comparison operator. Args: other: The Decimal to compare with. Returns: - True if self is less than other, False otherwise. + True if self is greater than or equal to other, False otherwise. """ - return decimojo.logic.less(self, other) + return decimojo.comparison.greater_equal(self, other) fn __le__(self, other: Decimal) -> Bool: """ @@ -1125,7 +1119,7 @@ struct Decimal( Returns: True if self is less than or equal to other, False otherwise. """ - return decimojo.logic.less_equal(self, other) + return decimojo.comparison.less_equal(self, other) fn __eq__(self, other: Decimal) -> Bool: """ @@ -1137,7 +1131,7 @@ struct Decimal( Returns: True if self is equal to other, False otherwise. """ - return decimojo.logic.equal(self, other) + return decimojo.comparison.equal(self, other) fn __ne__(self, other: Decimal) -> Bool: """ @@ -1149,7 +1143,7 @@ struct Decimal( Returns: True if self is not equal to other, False otherwise. """ - return decimojo.logic.not_equal(self, other) + return decimojo.comparison.not_equal(self, other) # ===------------------------------------------------------------------=== # # Other dunders that implements traits diff --git a/src/decimojo/logic.mojo b/src/decimojo/logic.mojo deleted file mode 100644 index 7c4cf46..0000000 --- a/src/decimojo/logic.mojo +++ /dev/null @@ -1,231 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE -# ===----------------------------------------------------------------------=== # -# -# Implements logic operations for the Decimal type -# -# ===----------------------------------------------------------------------=== # -# -# List of functions in this module: -# -# greater(a: Decimal, b: Decimal) -> Bool: Returns True if a > b -# greater_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a >= b -# less(a: Decimal, b: Decimal) -> Bool: Returns True if a < b -# less_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a <= b -# equal(a: Decimal, b: Decimal) -> Bool: Returns True if a == b -# not_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a != b -# -# List of internal functions in this module: -# -# _compare_abs(a: Decimal, b: Decimal) -> Int: Compares absolute values of two Decimals -# -# ===----------------------------------------------------------------------=== # - -""" -Implements functions for comparison operations on Decimal objects. -""" - -from decimojo.decimal import Decimal -import decimojo.utility - - -fn greater(a: Decimal, b: Decimal) -> Bool: - """ - Returns True if a > b. - - Args: - a: First Decimal value. - b: Second Decimal value. - - Returns: - True if a is greater than b, False otherwise. - """ - # Handle special case where either or both are zero - if a.is_zero() and b.is_zero(): - return False # Zero equals zero - if a.is_zero(): - return b.is_negative() # a=0 > b only if b is negative - if b.is_zero(): - return ( - not a.is_negative() and not a.is_zero() - ) # a > b=0 only if a is positive and non-zero - - # If they have different signs, positive is always greater - if a.is_negative() != b.is_negative(): - return not a.is_negative() # a > b if a is positive and b is negative - - # Now we know they have the same sign - # Compare absolute values, considering the sign - var compare_result = _compare_abs(a, b) - - if a.is_negative(): - # For negative numbers, the one with smaller absolute value is greater - return compare_result < 0 - else: - # For positive numbers, the one with larger absolute value is greater - return compare_result > 0 - - -fn greater_equal(a: Decimal, b: Decimal) -> Bool: - """ - Returns True if a >= b. - - Args: - a: First Decimal value. - b: Second Decimal value. - - Returns: - True if a is greater than or equal to b, False otherwise. - """ - # Handle special case where either or both are zero - if a.is_zero() and b.is_zero(): - return True # Zero equals zero - if a.is_zero(): - return ( - b.is_zero() or b.is_negative() - ) # a=0 >= b only if b is zero or negative - if b.is_zero(): - return ( - a.is_negative() == False - ) # a >= b=0 only if a is positive or zero - - # If they have different signs, positive is always greater - if a.is_negative() != b.is_negative(): - return not a.is_negative() # a >= b if a is positive and b is negative - - # Now we know they have the same sign - # Compare absolute values, considering the sign - var compare_result = _compare_abs(a, b) - - if a.is_negative(): - # For negative numbers, the one with smaller or equal absolute value is greater or equal - return compare_result <= 0 - else: - # For positive numbers, the one with larger or equal absolute value is greater or equal - return compare_result >= 0 - - -fn less(a: Decimal, b: Decimal) -> Bool: - """ - Returns True if a < b. - - Args: - a: First Decimal value. - b: Second Decimal value. - - Returns: - True if a is less than b, False otherwise. - """ - # We can use the greater function with arguments reversed - return greater(b, a) - - -fn less_equal(a: Decimal, b: Decimal) -> Bool: - """ - Returns True if a <= b. - - Args: - a: First Decimal value. - b: Second Decimal value. - - Returns: - True if a is less than or equal to b, False otherwise. - """ - # We can use the greater_equal function with arguments reversed - return greater_equal(b, a) - - -fn equal(a: Decimal, b: Decimal) -> Bool: - """ - Returns True if a == b. - - Args: - a: First Decimal value. - b: Second Decimal value. - - Returns: - True if a equals b, False otherwise. - """ - # If both are zero, they are equal regardless of scale or sign - if a.is_zero() and b.is_zero(): - return True - - # If signs differ, they're not equal - if a.is_negative() != b.is_negative(): - return False - - # Compare absolute values - return _compare_abs(a, b) == 0 - - -fn not_equal(a: Decimal, b: Decimal) -> Bool: - """ - Returns True if a != b. - - Args: - a: First Decimal value. - b: Second Decimal value. - - Returns: - True if a is not equal to b, False otherwise. - """ - # Simply negate the equal function - return not equal(a, b) - - -fn _compare_abs(a: Decimal, b: Decimal) -> Int: - """ - Internal helper to compare absolute values of two Decimal numbers. - - Returns: - - Positive value if |a| > |b| - - Zero if |a| = |b| - - Negative value if |a| < |b| - - raises: - Error: Calling `scale_up()` failed. - """ - # Normalize scales by scaling up the one with smaller scale - var scale_a = a.scale() - var scale_b = b.scale() - - # Create temporary copies that we will scale - var a_copy = a - var b_copy = b - - # Scale up the decimal with smaller scale to match the other - # TODO: Treat this error properly - if scale_a < scale_b: - try: - a_copy = decimojo.utility.scale_up(a, scale_b - scale_a) - except: - a_copy = a - elif scale_b < scale_a: - try: - b_copy = decimojo.utility.scale_up(b, scale_a - scale_b) - except: - b_copy = b - - # Now both have the same scale, compare integer components - # Compare high parts first (most significant) - if a_copy.high > b_copy.high: - return 1 - if a_copy.high < b_copy.high: - return -1 - - # High parts equal, compare mid parts - if a_copy.mid > b_copy.mid: - return 1 - if a_copy.mid < b_copy.mid: - return -1 - - # Mid parts equal, compare low parts (least significant) - if a_copy.low > b_copy.low: - return 1 - if a_copy.low < b_copy.low: - return -1 - - # All components are equal - return 0 diff --git a/src/decimojo/maths/__init__.mojo b/src/decimojo/maths/__init__.mojo index db7309d..f44a61a 100644 --- a/src/decimojo/maths/__init__.mojo +++ b/src/decimojo/maths/__init__.mojo @@ -25,12 +25,13 @@ # lcm(a: Decimal, b: Decimal): Returns least common multiple of a and b # ===----------------------------------------------------------------------=== # -from .basic import ( +from .arithmetics import ( add, subtract, + negative, + absolute, multiply, true_divide, ) from .exp import power, sqrt from .rounding import round -from .misc import absolute diff --git a/src/decimojo/maths/basic.mojo b/src/decimojo/maths/arithmetics.mojo similarity index 98% rename from src/decimojo/maths/basic.mojo rename to src/decimojo/maths/arithmetics.mojo index b8b92cf..49bae35 100644 --- a/src/decimojo/maths/basic.mojo +++ b/src/decimojo/maths/arithmetics.mojo @@ -286,6 +286,44 @@ fn subtract(x1: Decimal, x2: Decimal) raises -> Decimal: raise Error("Error in `subtract()`; ", e) +fn negative(x: Decimal) -> Decimal: + """ + Returns the negative of a Decimal number. + + Args: + x: The Decimal value to compute the negative of. + + Returns: + A new Decimal containing the negative of x. + """ + + var result = x + + if x.is_zero(): + # Set the sign bit to 0 and keep the scale bits + result.flags &= ~Decimal.SIGN_MASK + + else: + result.flags ^= Decimal.SIGN_MASK # Flip sign bit + + return result + + +fn absolute(x: Decimal) -> Decimal: + """ + Returns the absolute value of a Decimal number. + + Args: + x: The Decimal value to compute the absolute value of. + + Returns: + A new Decimal containing the absolute value of x. + """ + if x.is_negative(): + return -x + return x + + fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal: """ Multiplies two Decimal values and returns a new Decimal containing the product. diff --git a/src/decimojo/maths/hyper.mojo b/src/decimojo/maths/hyper.mojo deleted file mode 100644 index a3474c3..0000000 --- a/src/decimojo/maths/hyper.mojo +++ /dev/null @@ -1,9 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE -# ===----------------------------------------------------------------------=== # -# -# Implements hyperbolic functions for the Decimal type -# -# ===----------------------------------------------------------------------=== # diff --git a/src/decimojo/maths/misc.mojo b/src/decimojo/maths/misc.mojo deleted file mode 100644 index 5a6cea9..0000000 --- a/src/decimojo/maths/misc.mojo +++ /dev/null @@ -1,24 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE -# ===----------------------------------------------------------------------=== # -# -# Implements miscellaneous mathematical functions for the Decimal type -# -# ===----------------------------------------------------------------------=== # - - -fn absolute(x: Decimal) raises -> Decimal: - """ - Returns the absolute value of a Decimal number. - - Args: - x: The Decimal value to compute the absolute value of. - - Returns: - A new Decimal containing the absolute value of x. - """ - if x.is_negative(): - return -x - return x diff --git a/src/decimojo/maths/trig.mojo b/src/decimojo/maths/trig.mojo deleted file mode 100644 index 7fb4e27..0000000 --- a/src/decimojo/maths/trig.mojo +++ /dev/null @@ -1,9 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Distributed under the Apache 2.0 License with LLVM Exceptions. -# See LICENSE and the LLVM License for more information. -# https://github.com/forFudan/decimojo/blob/main/LICENSE -# ===----------------------------------------------------------------------=== # -# -# Implements trigonometric functions for the Decimal type -# -# ===----------------------------------------------------------------------=== # diff --git a/tests/test_logic.mojo b/tests/test_comparison.mojo similarity index 99% rename from tests/test_logic.mojo rename to tests/test_comparison.mojo index 44486ae..2512209 100644 --- a/tests/test_logic.mojo +++ b/tests/test_comparison.mojo @@ -3,7 +3,7 @@ Test Decimal logic operations for comparison, including basic comparisons, edge cases, special handling for zero values, and operator overloads. """ from decimojo.prelude import dm, Decimal, RoundingMode -from decimojo.logic import ( +from decimojo.comparison import ( greater, greater_equal, less,