-
Notifications
You must be signed in to change notification settings - Fork 1
/
field.mojo
80 lines (59 loc) · 2.18 KB
/
field.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Ported from the Python version of StarkWare's STARK 101 tutorial [1].
# [1]: https://github.com/starkware-industries/stark101
@value
struct FieldElement(CollectionElement, Stringable):
    """An element of the prime field GF(p), p = 3 * 2**30 + 1.

    `val` always holds the canonical representative in [0, p); every
    constructor and operator reduces modulo `characteristic()`.
    """

    # ------------------------------------------------------------------------
    # Static methods

    @staticmethod
    fn characteristic() -> Int:
        """Returns the field prime p = 3 * 2**30 + 1 (= 3221225473)."""
        return 3 * 2**30 + 1

    @staticmethod
    fn primitive_element() -> Self:
        """Returns the element 5, which this code designates as the
        field's primitive element."""
        return Self(5)

    @staticmethod
    fn _mul_mod(a: Int, b: Int) -> Int:
        """Returns (a * b) mod p without overflowing a 64-bit Int.

        Assumes 0 <= a, b < 2**32, which holds for any reduced `val`.
        The naive product can reach (p - 1)**2 ~ 1.04e19 > 2**63 - 1 and
        would wrap silently, so `b` is split into 16-bit halves:
            a * b = a * hi * 2**16 + a * lo
        and every intermediate below stays under 2**49.
        """
        var p = Self.characteristic()
        var hi = b // 65536
        var lo = b % 65536
        var hi_part = (a * hi) % p  # < p < 2**32
        return (hi_part * 65536 + a * lo) % p

    # ------------------------------------------------------------------------
    # Fields and basic methods

    var val: Int

    fn __init__(inout self, val: Int):
        """Creates the element congruent to `val`, reduced into [0, p).

        Negative inputs (e.g. from `__neg__` / `__sub__`) rely on Mojo's
        Int `%` returning a result with the sign of the divisor.
        """
        self.val = val % Self.characteristic()

    fn __str__(self) -> String:
        """Returns the canonical representative as a decimal string."""
        return String(self.val)

    # ------------------------------------------------------------------------
    # Arithmetic operators

    fn __eq__(self, other: Self) -> Bool:
        """True when both elements have the same canonical value."""
        return self.val == other.val

    fn __ne__(self, other: Self) -> Bool:
        """Negation of `__eq__`."""
        return self.val != other.val

    fn __neg__(self) -> Self:
        """Additive inverse (-self mod p)."""
        return Self(-self.val)

    fn __add__(self, other: Self) -> Self:
        """Field addition; the raw sum is < 2p, so it cannot overflow."""
        return Self(self.val + other.val)

    fn __sub__(self, other: Self) -> Self:
        """Field subtraction; a negative difference is reduced by __init__."""
        return Self(self.val - other.val)

    fn __mul__(self, other: Self) -> Self:
        """Field multiplication (overflow-safe, see `_mul_mod`)."""
        return Self(Self._mul_mod(self.val, other.val))

    fn __imul__(inout self, other: Self):
        """In-place field multiplication (overflow-safe)."""
        self.val = Self._mul_mod(self.val, other.val)

    fn inverse(self) -> Self:
        """Returns the multiplicative inverse of `self`.

        Runs the extended Euclidean algorithm on (p, val), tracking only
        the coefficient of `val`. Zero has no inverse: the debug assertion
        fires in debug builds, otherwise Self(0) is returned.
        """
        var t: Int = 0
        var new_t: Int = 1
        var r = Self.characteristic()
        var new_r = self.val
        while new_r != 0:
            var quotient = r // new_r
            t, new_t = new_t, (t - (quotient * new_t))
            r, new_r = new_r, (r - (quotient * new_r))
        # gcd(p, val) must be 1; anything else means val == 0.
        debug_assert(r == 1, "inverse() failed")
        # t may be negative; __init__ reduces it into [0, p).
        return Self(t)

    fn __truediv__(self, other: Self) -> Self:
        """Field division: self * other^-1."""
        return self * other.inverse()

    fn __pow__(self, _n: Int) -> Self:
        """Raises `self` to the non-negative integer power `_n`.

        Uses right-to-left square-and-multiply. Negative exponents are
        rejected by a debug assertion; in release builds they yield 1.
        """
        debug_assert(_n >= 0, "unexpected negative argument")
        var base = self
        var result = Self(1)
        var n = _n
        while n > 0:
            if n % 2 != 0:
                result *= base
            n = n // 2
            # Assign instead of `base *= base` so `__imul__` never receives
            # the same value as both its inout receiver and its argument.
            base = base * base
        return result

    fn __pow__(self, n: Self) -> Self:
        """Raises `self` to the power given by the representative of `n`."""
        return self.__pow__(n.val)