Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 1 addition & 149 deletions src/decimojo/biguint/arithmetics.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -604,154 +604,6 @@ fn divmod(x1: BigUInt, x2: BigUInt) raises -> Tuple[BigUInt, BigUInt]:
# ===----------------------------------------------------------------------=== #


# TODO: The subtraction can be underflowed. Use signed integers for the subtraction
fn multiply_toom_cook_3(x1: BigUInt, x2: BigUInt) raises -> BigUInt:
"""Implements Toom-Cook 3-way multiplication algorithm.

Args:
x1: First operand.
x2: Second operand.

Returns:
Product of x1 and x2.

Notes:

This algorithm splits each number into 3 parts and performs 5 multiplications
instead of 9, achieving O(n^log₃5) ≈ O(n^1.465) complexity.
"""
# Special cases
if x1.is_zero() or x2.is_zero():
return BigUInt()
if x1.is_one():
return x2
if x2.is_one():
return x1

# # Basic multiplication is faster for small numbers
# if len(x1.words) < 10 or len(x2.words) < 10:
# return multiply(x1, x2)

# Determine size for splitting
var max_len = max(len(x1.words), len(x2.words))
var k = (max_len + 2) // 3 # Split into thirds

# Split the numbers into three parts each: a = a₂·β² + a₁·β + a₀
var a0_words = List[UInt32]()
var a1_words = List[UInt32]()
var a2_words = List[UInt32]()
var b0_words = List[UInt32]()
var b1_words = List[UInt32]()
var b2_words = List[UInt32]()

# Extract parts from x1
for i in range(min(k, len(x1.words))):
a0_words.append(x1.words[i])
for i in range(k, min(2 * k, len(x1.words))):
a1_words.append(x1.words[i])
for i in range(2 * k, len(x1.words)):
a2_words.append(x1.words[i])

# Extract parts from x2
for i in range(min(k, len(x2.words))):
b0_words.append(x2.words[i])
for i in range(k, min(2 * k, len(x2.words))):
b1_words.append(x2.words[i])
for i in range(2 * k, len(x2.words)):
b2_words.append(x2.words[i])

a0 = BigUInt.from_list(a0_words^)
a1 = BigUInt.from_list(a1_words^)
a2 = BigUInt.from_list(a2_words^)
b0 = BigUInt.from_list(b0_words^)
b1 = BigUInt.from_list(b1_words^)
b2 = BigUInt.from_list(b2_words^)

# Remove trailing zeros
a0.remove_leading_empty_words()
a1.remove_leading_empty_words()
a2.remove_leading_empty_words()
b0.remove_leading_empty_words()
b1.remove_leading_empty_words()
b2.remove_leading_empty_words()

print("DEBUG: a0 =", a0)
print("DEBUG: a1 =", a1)
print("DEBUG: a2 =", a2)
print("DEBUG: b0 =", b0)
print("DEBUG: b1 =", b1)
print("DEBUG: b2 =", b2)

# Evaluate at points 0, 1, -1, 2, ∞
# p₀ = a₀
var p0_a = a0
# p₁ = a₀ + a₁ + a₂
var p1_a = a0 + a1 + a2
# p₂ = a₀ - a₁ + a₂
var p2_a = a0 + a2 - a1
# p₃ = a₀ + 2a₁ + 4a₂
var p3_a = a0 + a1 * BigUInt(UInt32(2)) + a2 * BigUInt(UInt32(4))
# p₄ = a₂
var p4_a = a2

# Same for b
var p0_b = b0
var p1_b = add(add(b0, b1), b2)
var p2_b = add(subtract(b0, b1), b2)
var b1_times2 = add(b1, b1)
var b2_times4 = add(add(b2, b2), add(b2, b2))
var p3_b = add(add(b0, b1_times2), b2_times4)
var p4_b = b2

# Perform pointwise multiplication
var r0 = multiply(p0_a, p0_b) # at 0
var r1 = multiply(p1_a, p1_b) # at 1
var r2 = multiply(p2_a, p2_b) # at -1
var r3 = multiply(p3_a, p3_b) # at 2
var r4 = multiply(p4_a, p4_b) # at ∞

# Interpolate to get coefficients of the result
# c₀ = r₀
var c0 = r0

# c₄ = r₄
var c4 = r4

# TODO: The subtraction can be underflowed. Use signed integers for the subtraction
# c₃ = (r₃ - r₁)/3 - (r₄ - r₂)/2 + r₄·5/6
var t1 = (r3 - r1) // BigUInt(UInt32(3))
var t2 = (r4 - r2) // BigUInt(UInt32(2))
var t3 = r4 * BigUInt(UInt32(5)) // BigUInt(UInt32(6))
var c3 = t1 + t3 - t2

# c₂ = (r₂ - r₀)/2 - r₄
var c2 = (r2 - r0) // BigUInt(UInt32(2)) - r4

# c₁ = r₁ - r₀ - c₃ - c₄ - c₂
var c1 = r1 - r0 - c3 - c4 - c2

# Combine the coefficients to get the result
var result = c0

# c₁ * β
var c1_shifted = shift_words_left(c1, k)
result = result + c1_shifted

# c₂ * β²
var c2_shifted = shift_words_left(c2, 2 * k)
result = result + c2_shifted

# c₃ * β³
var c3_shifted = shift_words_left(c3, 3 * k)
result = result + c3_shifted

# c₄ * β⁴
var c4_shifted = shift_words_left(c4, 4 * k)
result = result + c4_shifted

return result


fn scale_up_by_power_of_10(x: BigUInt, n: Int) raises -> BigUInt:
"""Multiplies a BigUInt by 10^n (n>=0).

Expand Down Expand Up @@ -1154,7 +1006,7 @@ fn estimate_quotient(


fn shift_words_left(num: BigUInt, positions: Int) -> BigUInt:
"""Shifts a BigUInt left by adding leading zeros.
"""Shifts a BigUInt left by adding trailing zeros.
Equivalent to multiplying by 10^(9*positions)."""
if num.is_zero():
return BigUInt()
Expand Down
7 changes: 7 additions & 0 deletions src/decimojo/biguint/biguint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1021,3 +1021,10 @@ struct BigUInt(Absable, IntableRaising, Writable):
1,
)
return result^

@always_inline
fn shift_words_left(self, position: Int) -> Self:
"""Shifts the words of the BigUInt to the left by `position` bits.
See `arithmetics.shift_words_left()` for more information.
"""
return decimojo.biguint.arithmetics.shift_words_left(self, position)