diff --git a/src/decimojo/bigdecimal/arithmetics.mojo b/src/decimojo/bigdecimal/arithmetics.mojo index 894a923..d601c12 100644 --- a/src/decimojo/bigdecimal/arithmetics.mojo +++ b/src/decimojo/bigdecimal/arithmetics.mojo @@ -21,7 +21,6 @@ Implements functions for mathematical operations on BigDecimal objects. import time import testing -from decimojo.decimal.decimal import Decimal from decimojo.rounding_mode import RoundingMode import decimojo.utility diff --git a/src/decimojo/bigdecimal/exponential.mojo b/src/decimojo/bigdecimal/exponential.mojo index 24a4d4d..dbb2b8a 100644 --- a/src/decimojo/bigdecimal/exponential.mojo +++ b/src/decimojo/bigdecimal/exponential.mojo @@ -192,7 +192,7 @@ fn integer_power( fill_zeros_to_precision=False, ) - exp_value.floor_divide_inplace_by_2() + decimojo.biguint.arithmetics.floor_divide_inplace_by_2(exp_value) # For negative exponents, compute reciprocal if is_negative_exponent: diff --git a/src/decimojo/biguint/arithmetics.mojo b/src/decimojo/biguint/arithmetics.mojo index 84d38f5..ed875f0 100644 --- a/src/decimojo/biguint/arithmetics.mojo +++ b/src/decimojo/biguint/arithmetics.mojo @@ -66,7 +66,8 @@ from decimojo.rounding_mode import RoundingMode # ceil_modulo(x1: BigUInt, x2: BigUInt) -> BigUInt # divmod(x1: BigUInt, x2: BigUInt) -> Tuple[BigUInt, BigUInt] # -# normalize_carries(x: BigUInt) -> None +# normalize_carries_lt_2_bases(x: BigUInt) -> None +# normalize_carries_lt4_bases(x: BigUInt) -> None # power_of_10(n: Int) -> BigUInt # ===----------------------------------------------------------------------=== # @@ -245,7 +246,7 @@ fn add_simd(x: BigUInt, y: BigUInt) -> BigUInt: ) var result = BigUInt(words=words^) - normalize_carries(result) + normalize_carries_lt_2_bases(result) return result^ @@ -397,7 +398,7 @@ fn add_inplace(mut x: BigUInt, y: BigUInt) -> None: vectorize[vector_add, BigUInt.VECTOR_WIDTH](len(y.words)) # Normalize carries after addition - normalize_carries(x) + normalize_carries_lt_2_bases(x) return @@ -1010,17 +1011,17 @@ fn multiply_karatsuba( fn multiply_inplace_by_uint32(mut x: BigUInt, y: UInt32): - """Multiplies a BigUInt by an UInt32 word in-place. + """Multiplies in-place a BigUInt by a UInt32 value. Args: x: The BigUInt value to multiply. y: The single word to multiply by. """ - if y == 0: - x.words = List[UInt32](0) - return - - if y == 1: + # Short circuit cases when y is between 0 and 4 + # See `multiply_inplace_by_uint32_le_4()` for details + # TODO: Check the performance of `y <= 4` + if y <= 2: + multiply_inplace_by_uint32_le_4(x, y) return var y_as_uint64 = UInt64(y) @@ -1036,6 +1037,78 @@ fn multiply_inplace_by_uint32(mut x: BigUInt, y: UInt32): x.words.append(UInt32(carry)) +@always_inline +fn multiply_inplace_by_uint32_le_4(mut x: BigUInt, y: UInt32): + """Multiplies in-place a BigUInt by a UInt32 value which is between 0 and 4. + + Args: + x: The BigUInt value to multiply. + y: The single word to multiply by. It must be between 0 and 4. + + Notes: + + This function will be used in the `multiply_inplace_by_uint32()` function. + It is optimized for the case where y is between 0 and 4. + + When a valid word times 2, 3, or 4, the result is no larger than 4*10^9, + which is less than 2^32-1. This means that we do not need to use UInt64 to + store the product but use UInt32 directly. We can first use SIMD to do + word-by-word multiplication, and then handle the carries. + + This function works the best when y is 0, 1, or 2. For y = 3 or 4, the + normalization of carries is more expensive and may not compensate for the + extra loop overhead. + """ + + # y is 0, x becomes 1 + if y == 0: + x.words = List[UInt32](0) + return + + # y is 1, x stays the same + if y == 1: + return + + # y is 2, we can just shift the digits of each word to the left by 1 + @parameter + fn vector_multiply_by_2[simd_width: Int](i: Int): + """Shifts the digits of each word to the left by 1.""" + x.words.data.store[width=simd_width]( + i, x.words.data.load[width=simd_width](i) << 1 + ) + + if y == 2: + vectorize[vector_multiply_by_2, BigUInt.VECTOR_WIDTH](len(x.words)) + normalize_carries_lt_2_bases(x) + return + + # y is 3, we can just multiply the digits of each word by 3 + @parameter + fn vector_multiply_by_3[simd_width: Int](i: Int): + """Multiplies the digits of each word by 3.""" + x.words.data.store[width=simd_width]( + i, x.words.data.load[width=simd_width](i) * 3 + ) + + if y == 3: + vectorize[vector_multiply_by_3, BigUInt.VECTOR_WIDTH](len(x.words)) + normalize_carries_lt_4_bases(x) + return + + # y is 4, we can just shift the digits of each word to the left by 2 + @parameter + fn vector_multiply_by_4[simd_width: Int](i: Int): + """Shifts the digits of each word to the left by 2.""" + x.words.data.store[width=simd_width]( + i, x.words.data.load[width=simd_width](i) << 2 + ) + + if y == 4: + vectorize[vector_multiply_by_4, BigUInt.VECTOR_WIDTH](len(x.words)) + normalize_carries_lt_4_bases(x) + return + + fn multiply_by_power_of_ten(x: BigUInt, n: Int) -> BigUInt: """Multiplies a BigUInt by 10^n if n > 0, otherwise doing nothing. @@ -1308,22 +1381,20 @@ fn floor_divide(x1: BigUInt, x2: BigUInt) raises -> BigUInt: len(x2.words) <= CUTOFF_BURNIKEL_ZIEGLER ): # I will normalize the divisor to improve quotient estimation - var normalization_factor: Int # Number of digits to shift + var ndigits_to_shift: Int # Number of digits to shift # Calculate normalization factor to make leading digit of divisor # as large as possible - normalization_factor = calculate_normalization_factor(x2.words[-1]) + ndigits_to_shift = calculate_number_of_shifted_digits_for_normalization( + x2.words[-1] + ) - if normalization_factor == 0: + if ndigits_to_shift == 0: # No normalization needed, just use the general division algorithm return floor_divide_school(x1, x2) else: # Normalize the divisor and dividend - var normalized_x1 = multiply_by_power_of_ten( - x1, normalization_factor - ) - var normalized_x2 = multiply_by_power_of_ten( - x2, normalization_factor - ) + var normalized_x1 = multiply_by_power_of_ten(x1, ndigits_to_shift) + var normalized_x2 = multiply_by_power_of_ten(x2, ndigits_to_shift) return floor_divide_school(normalized_x1, normalized_x2) # CASE: division of very, very large numbers @@ -1557,17 +1628,18 @@ fn floor_divide_inplace_by_2(mut x: BigUInt) -> None: if x.is_zero(): return - var carry: UInt32 = 0 - # Process from most significant to least significant word + var base: UInt32 = BigUInt.BASE + var is_carry: Bool = False for ith in range(len(x.words) - 1, -1, -1): - x.words[ith] += carry - carry = BigUInt.BASE if (x.words[ith] & 1) else 0 + if is_carry: + x.words[ith] += base + if x.words[ith] & 1: + is_carry = True + else: + is_carry = False x.words[ith] >>= 1 - - # Remove leading zeros - while len(x.words) > 1 and x.words[len(x.words) - 1] == 0: - x.words.resize(len(x.words) - 1, UInt32(0)) + x.remove_leading_empty_words() fn floor_divide_by_power_of_ten(x: BigUInt, n: Int) raises -> BigUInt: @@ -1688,18 +1760,16 @@ fn floor_divide_burnikel_ziegler( var normalized_b = b var normalized_a = a - var normalization_factor: Int + var ndigits_to_shift: Int if normalized_b.words[-1] == 0: normalized_b.remove_leading_empty_words() if normalized_b.words[-1] < 500_000_000: - normalization_factor = ( - decimojo.biguint.arithmetics.calculate_normalization_factor( - normalized_b.words[-1] - ) + ndigits_to_shift = decimojo.biguint.arithmetics.calculate_number_of_shifted_digits_for_normalization( + normalized_b.words[-1] ) else: - normalization_factor = 0 + ndigits_to_shift = 0 # The targeted number of blocks should be the smallest 2^k such that # 2^k >= number of words in normalized_b ceil divided by BLOCK_SIZE_OF_WORDS. @@ -1714,7 +1784,7 @@ fn floor_divide_burnikel_ziegler( var n_digits_to_scale_up = ( n - len(normalized_b.words) - ) * 9 + normalization_factor + ) * 9 + ndigits_to_shift decimojo.biguint.arithmetics.multiply_inplace_by_power_of_ten( normalized_b, n_digits_to_scale_up @@ -1724,8 +1794,15 @@ fn floor_divide_burnikel_ziegler( ) # The normalized_b is now 9 digits, but may still be smaller than 500_000_000. - var gap_ratio = BigUInt.BASE // normalized_b.words[-1] - if gap_ratio > 2: + var gap_ratio: UInt32 + if normalized_b.words[-1] >= 500_000_000: # Already normalized + gap_ratio = 1 + elif normalized_b.words[-1] >= 125_000_000: # 2x is enough + gap_ratio = 2 + else: # The most significant word is in [100_000_000, 125_000_000) + gap_ratio = BigUInt.BASE_MAX // normalized_b.words[-1] + + if gap_ratio >= 2: decimojo.biguint.arithmetics.multiply_inplace_by_uint32( normalized_b, gap_ratio ) @@ -1775,19 +1852,23 @@ fn floor_divide_three_by_two( n: Int, cut_off: Int, ) raises -> Tuple[BigUInt, BigUInt]: - """Divides a 3-word number by a 2-word number. + """Divides a 3-part number by a 2-part number. Args: - a2: The most significant word of the dividend. - a1: The middle word of the dividend. - a0: The least significant word of the dividend. - b1: The most significant word of the divisor. - b0: The least significant word of the divisor. - n: The number of words in the divisor. - cut_off: The minimum number of words for the recursive division. + a2: The most significant part of the dividend. + a1: The middle part of the dividend. + a0: The least significant part of the dividend. + b1: The most significant part of the divisor. + b0: The least significant part of the divisor. + n: The number of part in the divisor. + cut_off: The minimum number of part for the recursive division. Returns: A tuple containing the quotient and the remainder as BigUInt. + + Notes: + + a is a BigUInt with 3n words and b is a BigUInt with 2n words. """ var a2a1: BigUInt @@ -2127,7 +2208,7 @@ fn divmod(x1: BigUInt, x2: BigUInt) raises -> Tuple[BigUInt, BigUInt]: # ===----------------------------------------------------------------------=== # -fn normalize_carries(mut x: BigUInt): +fn normalize_carries_lt_2_bases(mut x: BigUInt): """Normalizes the values of words into valid range by carrying over. The initial values of the words should be in the range [0, BASE*2). @@ -2163,6 +2244,80 @@ fn normalize_carries(mut x: BigUInt): return +fn normalize_carries_lt_4_bases(mut x: BigUInt): + """Normalizes the values of words into valid range by carrying over. + The initial values of the words should be in the range [0, BASE * 4 - 4]. + + Notes: + + If we multiply a BigUInt numbers word-by-word by 3 or 4, we may end up with + a situation where some words are ge than BASE but le BASE * 4 - 4. + This function normalizes the carries, ensuring that all words are within the + valid range. It modifies the input BigUInt in-place. + """ + + # Yuhao ZHU: + # By construction, the words of x are in the range [0, BASE*4). + # Thus, the carry can only be 0, 1, 2, or 3. + var carry: UInt32 = 0 + for ref word in x.words: + if carry == 0: + if word <= UInt32(999_999_999): + pass # carry = 0 + elif word <= UInt32(1_999_999_999): + word -= UInt32(1_000_000_000) + carry = 1 + elif word <= UInt32(2_999_999_999): + word -= UInt32(2_000_000_000) + carry = 2 + else: # 3_000_000_000 <= word <= 3_999_999_996 + word -= UInt32(3_000_000_000) + carry = 3 + elif carry == 1: + if word <= UInt32(999_999_998): + word += 1 + carry = 0 + elif word <= UInt32(1_999_999_998): + word = word + 1 - UInt32(1_000_000_000) + carry = 1 + elif word <= UInt32(2_999_999_998): + word = word + 1 - UInt32(2_000_000_000) + carry = 2 + else: # 2_999_999_999 <= word <= 3_999_999_996 + word = word + 1 - UInt32(3_000_000_000) + carry = 3 + elif carry == 2: + if word <= UInt32(999_999_997): + word += 2 + carry = 0 + elif word <= UInt32(1_999_999_997): + word = word + 2 - UInt32(1_000_000_000) + carry = 1 + elif word <= UInt32(2_999_999_997): + word = word + 2 - UInt32(2_000_000_000) + carry = 2 + else: # 2_999_999_998 <= word <= 3_999_999_996 + word = word + 2 - UInt32(3_000_000_000) + carry = 3 + else: # carry == 3 + if word <= UInt32(999_999_996): + word += 3 + carry = 0 + elif word <= UInt32(1_999_999_996): + word = word + 3 - UInt32(1_000_000_000) + carry = 1 + elif word <= UInt32(2_999_999_996): + word = word + 3 - UInt32(2_000_000_000) + carry = 2 + else: # 2_999_999_997 <= word <= 3_999_999_996 + word = word + 3 - UInt32(3_000_000_000) + carry = 3 + if carry > 0: + # If there is still a carry, we need to add a new word + x.words.append(UInt32(carry)) + return + + fn normalize_borrows(mut x: BigUInt): """Normalizes the values of words into valid range by borrowing. The caller should ensure that the final result is non-negative. @@ -2245,37 +2400,44 @@ fn power_of_10(n: Int) raises -> BigUInt: return result^ -fn calculate_normalization_factor(msw: UInt32) -> Int: - """Calculates the normalization factor based on the most significant word. - The normalized word should be as close to BASE as possible. +@always_inline +fn calculate_number_of_shifted_digits_for_normalization(msw: UInt32) -> Int: + """Calculates the number of digits to shift left for normalization. + + Args: + msw: The most significant word of the number to normalize. + + Returns: + The number of digits to shift left to normalize the number. Notes: This is a helper function for division algorithms. + The normalized word should be as close to BASE as possible. """ if msw < 10_000: if msw < 100: if msw < 10: - normalization_factor = 8 # Shift by 8 digits + ndigits = 8 # Shift by 8 digits else: # 10 <= msw < 100 - normalization_factor = 7 # Shift by 7 digits + ndigits = 7 # Shift by 7 digits else: # 100 <= msw < 10_000 if msw < 1_000: # 100 <= msw < 1_000 - normalization_factor = 6 # Shift by 6 digits + ndigits = 6 # Shift by 6 digits else: # 1_000 <= msw < 10_000: - normalization_factor = 5 # Shift by 5 digits + ndigits = 5 # Shift by 5 digits elif msw < 100_000_000: # 10_000 <= msw < 100_000_000 if msw < 1_000_000: if msw < 100_000: # 10_000 <= msw < 100_000 - normalization_factor = 4 # Shift by 4 digits + ndigits = 4 # Shift by 4 digits else: # 100_000 <= msw < 1_000_000 - normalization_factor = 3 # Shift by 3 digits + ndigits = 3 # Shift by 3 digits else: # 1_000_000 <= msw < 100_000_000 if msw < 10_000_000: # 1_000_000 <= msw < 10_000_000 - normalization_factor = 2 # Shift by 2 digits + ndigits = 2 # Shift by 2 digits else: # 10_000_000 <= msw < 100_000_000 - normalization_factor = 1 # Shift by 1 digit + ndigits = 1 # Shift by 1 digit else: # 100_000_000 <= msw < 1_000_000_000 - normalization_factor = 0 # No shift needed + ndigits = 0 # No shift needed - return normalization_factor + return ndigits