Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:

- name: Run tests
run: |
pixi run mojo test tests
bash ./tests/test_all.sh

- name: Install pre-commit
run: |
Expand Down
10 changes: 5 additions & 5 deletions pixi.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[project]
[workspace]
authors = ["ZHU Yuhao 朱宇浩 <[email protected]>"]
channels = ["https://conda.modular.com/max-nightly", "https://conda.modular.com/max", "https://repo.prefix.dev/modular-community", "conda-forge"]
description = "An arbitrary-precision decimal and integer mathematics library for Mojo"
Expand All @@ -9,7 +9,7 @@ readme = "README.md"
version = "0.6.0"

[dependencies]
mojo = "==0.25.6"
mojo = "==0.25.7.0"

[tasks]
# format the code
Expand All @@ -26,9 +26,9 @@ c = "clear && pixi run clean"
clean = "rm tests/decimojo.mojopkg && rm benches/decimojo.mojopkg && rm tests/tomlmojo.mojopkg"

# tests (use the mojo testing tool)
b = "pixi run t big"
t = "clear && pixi run package && pixi run mojo test tests -D ASSERT=all --filter"
test = "pixi run package && pixi run mojo test tests -D ASSERT=all --filter"
b = "clear && pixi run package && bash ./tests/test_big.sh"
t = "clear && pixi run package && bash ./tests/test_all.sh"
test = "pixi run package && bash ./tests/test_all.sh"

# benches
bdec = "clear && pixi run package && cd benches/bigdecimal && pixi run mojo run -I ../ bench.mojo && cd ../.. && pixi run clean"
Expand Down
2 changes: 1 addition & 1 deletion src/decimojo/bigdecimal/bigdecimal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ struct BigDecimal(
return Self(coefficient=BigUInt(words^), scale=0, sign=sign)

@staticmethod
fn from_uint(value: Int) -> Self:
fn from_uint(value: UInt) -> Self:
"""Creates a BigDecimal from an unsigned integer."""
return Self(coefficient=BigUInt.from_uint(value), scale=0, sign=False)

Expand Down
2 changes: 1 addition & 1 deletion src/decimojo/bigdecimal/trigonometric.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ fn arctan(x: BigDecimal, precision: Int) raises -> BigDecimal:
# arctan(x) = 2 * arctan(x / (1 + sqrt(1 + x²)))
# This is to ensure convergence of the Taylor series.
# print("Using identity for arctan with |x| <= 2")
print(bdec_1 + x * x)
# print(bdec_1 + x * x)
var sqrt_term = (bdec_1 + x * x).sqrt(precision=working_precision)
var x_divided = x.true_divide(
bdec_1 + sqrt_term, precision=working_precision
Expand Down
2 changes: 1 addition & 1 deletion src/decimojo/bigint/bigint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ struct BigInt(
# Other dunders
# ===------------------------------------------------------------------=== #

fn __merge_with__[other_type: __type_of(BigDecimal)](self) -> BigDecimal:
fn __merge_with__[other_type: type_of(BigDecimal)](self) -> BigDecimal:
"Merges this BigInt with a BigDecimal into a BigDecimal."
return BigDecimal(self)

Expand Down
6 changes: 3 additions & 3 deletions src/decimojo/biguint/arithmetics.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -300,18 +300,18 @@ fn add_slices_simd(
min(n_words_x_slice, n_words_y_slice)
)

var longer: Pointer[BigUInt, __origin_of(x, y)]
var longer: Pointer[BigUInt, origin_of(x, y)]
var n_words_longer_slice: Int
var n_words_shorter_slice: Int
var longer_start: Int

if n_words_x_slice >= n_words_y_slice:
longer = Pointer[BigUInt, __origin_of(x, y)](to=x)
longer = Pointer[BigUInt, origin_of(x, y)](to=x)
n_words_longer_slice = n_words_x_slice
n_words_shorter_slice = n_words_y_slice
longer_start = bounds_x[0]
else:
longer = Pointer[BigUInt, __origin_of(x, y)](to=y)
longer = Pointer[BigUInt, origin_of(x, y)](to=y)
n_words_longer_slice = n_words_y_slice
n_words_shorter_slice = n_words_x_slice
longer_start = bounds_y[0]
Expand Down
10 changes: 5 additions & 5 deletions src/decimojo/biguint/biguint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ struct BigUInt(
return Self()

var list_of_words = List[UInt32]()
var remainder: Int = value
var quotient: Int
var remainder: UInt = value
var quotient: UInt

while remainder != 0:
quotient = remainder // 1_000_000_000
Expand Down Expand Up @@ -904,7 +904,7 @@ struct BigUInt(
# " of UInt128 (340282366920938463463374607431768211455)"
# )

var result: UInt128 = 0
var result: UInt128

if len(self.words) == 1:
result = self.words._data.load[width=1]().cast[DType.uint128]()
Expand Down Expand Up @@ -1295,11 +1295,11 @@ struct BigUInt(
# Other dunders
# ===------------------------------------------------------------------=== #

fn __merge_with__[other_type: __type_of(BigInt)](self) -> BigInt:
fn __merge_with__[other_type: type_of(BigInt)](self) -> BigInt:
"Merges this BigUInt with a BigInt into a BigInt."
return BigInt(self)

fn __merge_with__[other_type: __type_of(BigDecimal)](self) -> BigDecimal:
fn __merge_with__[other_type: type_of(BigDecimal)](self) -> BigDecimal:
"Merges this BigUInt with a BigDecimal into a BigDecimal."
return BigDecimal(self)

Expand Down
6 changes: 3 additions & 3 deletions src/decimojo/decimal128/decimal128.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ struct Decimal128(

var flags: UInt32 = 0
flags |= (scale << Self.SCALE_SHIFT) & Self.SCALE_MASK
flags |= sign << 31
flags |= UInt32(sign) << 31

return Self(low, mid, high, flags)

Expand Down Expand Up @@ -489,7 +489,7 @@ struct Decimal128(

var result = UnsafePointer(to=value).bitcast[Decimal128]()[]
result.flags |= (scale << Self.SCALE_SHIFT) & Self.SCALE_MASK
result.flags |= sign << 31
result.flags |= UInt32(sign) << 31

return result

Expand Down Expand Up @@ -528,7 +528,7 @@ struct Decimal128(
if value_bytes_len == 0:
return Decimal128.ZERO()

if value_bytes_len != value_string_slice.char_length():
if value_bytes_len != Int(value_string_slice.char_length()):
raise Error(
String(
"There are invalid characters in decimal128 string: {}"
Expand Down
2 changes: 1 addition & 1 deletion src/decimojo/str.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ fn parse_numeric_string(
if value_bytes_len == 0:
raise Error("Error in `parse_numeric_string`: Empty string.")

if value_bytes_len != value_string_slice.char_length():
if value_bytes_len != Int(value_string_slice.char_length()):
raise Error(
String(
"There are invalid characters in the string of the number: {}"
Expand Down
66 changes: 50 additions & 16 deletions tests/bigdecimal/test_bigdecimal_arithmetics.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ fn test_bigdecimal_arithmetics() raises:
var toml = parse_file(file_path)
var test_cases: List[TestCase]

print("------------------------------------------------------")
print("Testing BigDecimal addition...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal addition...")
# print("------------------------------------------------------")

test_cases = load_test_cases(toml, "addition_tests")
count_wrong = 0
for test_case in test_cases:
var result = BDec(test_case.a) + BDec(test_case.b)
try:
Expand All @@ -47,13 +48,21 @@ fn test_bigdecimal_arithmetics() raises:
pydecimal.Decimal(test_case.a)
+ pydecimal.Decimal(test_case.b)
),
"\n",
)
count_wrong += 1
testing.assert_equal(
count_wrong,
0,
"Some test cases failed. See above for details.",
)

print("------------------------------------------------------")
print("Testing BigDecimal subtraction...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal subtraction...")
# print("------------------------------------------------------")

test_cases = load_test_cases(toml, "subtraction_tests")
count_wrong = 0
for test_case in test_cases:
var result = BDec(test_case.a) - BDec(test_case.b)
try:
Expand All @@ -74,13 +83,21 @@ fn test_bigdecimal_arithmetics() raises:
pydecimal.Decimal(test_case.a)
- pydecimal.Decimal(test_case.b)
),
"\n",
)
count_wrong += 1
testing.assert_equal(
count_wrong,
0,
"Some test cases failed. See above for details.",
)

print("------------------------------------------------------")
print("Testing BigDecimal multiplication...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal multiplication...")
# print("------------------------------------------------------")

test_cases = load_test_cases(toml, "multiplication_tests")
count_wrong = 0
for test_case in test_cases:
var result = BDec(test_case.a) * BDec(test_case.b)
try:
Expand All @@ -101,15 +118,25 @@ fn test_bigdecimal_arithmetics() raises:
pydecimal.Decimal(test_case.a)
* pydecimal.Decimal(test_case.b)
),
"\n",
)
count_wrong += 1
testing.assert_equal(
count_wrong,
0,
"Some test cases failed. See above for details.",
)

print("------------------------------------------------------")
print("Testing BigDecimal division...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal division...")
# print("------------------------------------------------------")

test_cases = load_test_cases(toml, "division_tests")
count_wrong = 0
for test_case in test_cases:
var result = BDec(test_case.a) / BDec(test_case.b)
var result = BDec(test_case.a).true_divide(
BDec(test_case.b), precision=28
)
try:
testing.assert_equal(
lhs=String(result),
Expand All @@ -128,12 +155,19 @@ fn test_bigdecimal_arithmetics() raises:
pydecimal.Decimal(test_case.a)
/ pydecimal.Decimal(test_case.b)
),
"\n",
)
count_wrong += 1
testing.assert_equal(
count_wrong,
0,
"Some test cases failed. See above for details.",
)


fn main() raises:
print("Running BigDecimal arithmetic tests")
# print("Running BigDecimal arithmetic tests")

test_bigdecimal_arithmetics()
testing.TestSuite.discover_tests[__functions_in_module()]().run()

print("All BigDecimal arithmetic tests passed!")
# print("All BigDecimal arithmetic tests passed!")
48 changes: 24 additions & 24 deletions tests/bigdecimal/test_bigdecimal_compare.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ fn test_bigdecimal_compare() raises:
var toml = parse_file(file_path)
var test_cases: List[TestCase]

print("------------------------------------------------------")
print("Testing BigDecimal compare_absolute...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal compare_absolute...")
# print("------------------------------------------------------")

test_cases = load_test_cases(toml, "compare_absolute_tests")
for test_case in test_cases:
Expand All @@ -29,9 +29,9 @@ fn test_bigdecimal_compare() raises:
msg=test_case.description,
)

print("------------------------------------------------------")
print("Testing BigDecimal > operator...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal > operator...")
# print("------------------------------------------------------")
test_cases = load_test_cases(toml, "greater_than_tests")
for test_case in test_cases:
var result = BDec(test_case.a) > BDec(test_case.b)
Expand All @@ -41,9 +41,9 @@ fn test_bigdecimal_compare() raises:
msg=test_case.description,
)

print("------------------------------------------------------")
print("Testing BigDecimal < operator...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal < operator...")
# print("------------------------------------------------------")
test_cases = load_test_cases(toml, "less_than_tests")
for test_case in test_cases:
var result = BDec(test_case.a) < BDec(test_case.b)
Expand All @@ -53,9 +53,9 @@ fn test_bigdecimal_compare() raises:
msg=test_case.description,
)

print("------------------------------------------------------")
print("Testing BigDecimal >= operator...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal >= operator...")
# print("------------------------------------------------------")
test_cases = load_test_cases(toml, "greater_than_or_equal_tests")
for test_case in test_cases:
var result = BDec(test_case.a) >= BDec(test_case.b)
Expand All @@ -65,9 +65,9 @@ fn test_bigdecimal_compare() raises:
msg=test_case.description,
)

print("------------------------------------------------------")
print("Testing BigDecimal <= operator...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal <= operator...")
# print("------------------------------------------------------")
test_cases = load_test_cases(toml, "less_than_or_equal_tests")
for test_case in test_cases:
var result = BDec(test_case.a) <= BDec(test_case.b)
Expand All @@ -77,9 +77,9 @@ fn test_bigdecimal_compare() raises:
msg=test_case.description,
)

print("------------------------------------------------------")
print("Testing BigDecimal == operator...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal == operator...")
# print("------------------------------------------------------")
test_cases = load_test_cases(toml, "equal_tests")
for test_case in test_cases:
var result = BDec(test_case.a) == BDec(test_case.b)
Expand All @@ -89,9 +89,9 @@ fn test_bigdecimal_compare() raises:
msg=test_case.description,
)

print("------------------------------------------------------")
print("Testing BigDecimal != operator...")
print("------------------------------------------------------")
# print("------------------------------------------------------")
# print("Testing BigDecimal != operator...")
# print("------------------------------------------------------")
test_cases = load_test_cases(toml, "not_equal_tests")
for test_case in test_cases:
var result = BDec(test_case.a) != BDec(test_case.b)
Expand All @@ -103,9 +103,9 @@ fn test_bigdecimal_compare() raises:


fn main() raises:
print("Running BigDecimal comparison tests")
# print("Running BigDecimal comparison tests")

# Run compare_absolute tests
test_bigdecimal_compare()
testing.TestSuite.discover_tests[__functions_in_module()]().run()

print("All BigDecimal comparison tests passed!")
# print("All BigDecimal comparison tests passed!")
Loading