Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gkr_nonnative intial review #1162

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
522 changes: 522 additions & 0 deletions frontend/variable.go

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions std/fiat-shamir/settings.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ package fiatshamir
import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/emulated"
)

type Settings struct {
@@ -12,6 +13,13 @@ type Settings struct {
Hash hash.FieldHasher
}

type SettingsFr[FR emulated.FieldParams] struct {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
Transcript *Transcript
Prefix string
BaseChallenges []emulated.Element[FR]
Hash hash.FieldHasher
}

func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings {
return Settings{
Transcript: transcript,
@@ -20,9 +28,24 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro
}
}

func WithTranscriptFr[FR emulated.FieldParams](transcript *Transcript, prefix string, baseChallenges ...emulated.Element[FR]) SettingsFr[FR] {
return SettingsFr[FR]{
Transcript: transcript,
Prefix: prefix,
BaseChallenges: baseChallenges,
}
}

func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings {
return Settings{
BaseChallenges: baseChallenges,
Hash: hash,
}
}

func WithHashFr[FR emulated.FieldParams](hash hash.FieldHasher, baseChallenges ...emulated.Element[FR]) SettingsFr[FR] {
return SettingsFr[FR]{
BaseChallenges: baseChallenges,
Hash: hash,
}
}
1 change: 1 addition & 0 deletions std/gkr/gkr.go
Original file line number Diff line number Diff line change
@@ -308,6 +308,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof,
claims := newClaimsManager(c, assignment)

var firstChallenge []frontend.Variable
// why no bind values here?
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
if err != nil {
return err
12 changes: 11 additions & 1 deletion std/gkr/gkr_test.go
Original file line number Diff line number Diff line change
@@ -8,8 +8,11 @@ import (
"reflect"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/profile"
fiatshamir "github.com/consensys/gnark/std/fiat-shamir"
"github.com/consensys/gnark/std/polynomial"
"github.com/consensys/gnark/test"
@@ -74,6 +77,14 @@ func generateTestVerifier(path string, options ...option) func(t *testing.T) {
TestCaseName: path,
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Omit debugging code.

p:= profile.Start()
frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, validCircuit)
p.Stop()

fmt.Println(p.NbConstraints())
fmt.Println(p.Top())
//r1cs.CheckUnconstrainedWires()

invalidCircuit := &GkrVerifierCircuit{
Input: make([][]frontend.Variable, len(testCase.Input)),
Output: make([][]frontend.Variable, len(testCase.Output)),
@@ -327,7 +338,6 @@ func TestLoadCircuit(t *testing.T) {
assert.Equal(t, []*Wire{}, c[0].Inputs)
assert.Equal(t, []*Wire{&c[0]}, c[1].Inputs)
assert.Equal(t, []*Wire{&c[1]}, c[2].Inputs)

}

func TestTopSortTrivial(t *testing.T) {
23 changes: 23 additions & 0 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ import (

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/std/math/bits"
)

// Element defines an element in the ring of integers modulo n. The integer
@@ -106,3 +107,25 @@ func (e *Element[T]) copy() *Element[T] {
r.internal = e.internal
return &r
}

// newInternalElement sets the limbs and overflow. Given as a function for later
// possible refactor.
func newInternalElement[T FieldParams](limbs []frontend.Variable, overflow uint) *Element[T] {
return &Element[T]{Limbs: limbs, overflow: overflow, internal: true}
}

// FromBits returns a new Element given the bits is little-endian order.
func FromBits[FR FieldParams](api frontend.API, bs ...frontend.Variable) *Element[FR] {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
var fParams FR
nbLimbs := (uint(len(bs)) + fParams.BitsPerLimb() - 1) / fParams.BitsPerLimb()
limbs := make([]frontend.Variable, nbLimbs)
for i := uint(0); i < nbLimbs-1; i++ {
limbs[i] = bits.FromBinary(api, bs[i*fParams.BitsPerLimb():(i+1)*fParams.BitsPerLimb()])
}
limbs[nbLimbs-1] = bits.FromBinary(api, bs[(nbLimbs-1)*fParams.BitsPerLimb():])
return newInternalElement[FR](limbs, 0)
}

func CreateConstElement[T FieldParams](v interface{}) *Element[T] {
return newConstElement[T](v)
}
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 10 additions & 0 deletions std/math/emulated/field_mul.go
Original file line number Diff line number Diff line change
@@ -414,6 +414,16 @@ func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] {
return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b)
}

// // MulAcc computes a*b and reduces it modulo the field order. The returned Element
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
// // has default number of limbs and zero overflow. If the result wouldn't fit
// // into Element, then locally reduces the inputs first. Doesn't mutate inputs.
// //
// // For multiplying by a constant, use [Field[T].MulConst] method which is more
// // efficient.
// func (f *Field[T]) MulAcc(a, b *Element[T], c *Element[T]) *Element[T] {
// return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b)
// }

// MulMod computes a*b and reduces it modulo the field order. The returned Element
// has default number of limbs and zero overflow.
//
16 changes: 16 additions & 0 deletions std/math/polynomial/polynomial.go
Original file line number Diff line number Diff line change
@@ -22,6 +22,10 @@ type Univariate[FR emulated.FieldParams] []emulated.Element[FR]
// coefficients.
type Multilinear[FR emulated.FieldParams] []emulated.Element[FR]

func (ml *Multilinear[FR]) NumVars() int {
return bits.Len(uint(len(*ml) - 1))
}

func valueOf[FR emulated.FieldParams](univ []*big.Int) []emulated.Element[FR] {
ret := make([]emulated.Element[FR], len(univ))
for i := range univ {
@@ -89,6 +93,18 @@ func New[FR emulated.FieldParams](api frontend.API) (*Polynomial[FR], error) {
}, nil
}

func (p *Polynomial[FR]) Mul(a, b *emulated.Element[FR]) *emulated.Element[FR] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You do not need to have Mul, Add and AssertIsEqual as methods on Polynomial. You can directly call emulated.Field methods Mul, Add, AssertIsEqual on your inputs. I'd recommend removing methods here.

Copy link
Author

@amit0365 amit0365 Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed. Now creating a new instance of field with emulated.Field[FR]{} and using the api from there.

return p.f.Mul(a, b)
}

func (p *Polynomial[FR]) Add(a, b *emulated.Element[FR]) *emulated.Element[FR] {
return p.f.Add(a, b)
}

func (p *Polynomial[FR]) AssertIsEqual(a, b *emulated.Element[FR]) {
p.f.AssertIsEqual(a, b)
}

// EvalUnivariate evaluates univariate polynomial at a point at. It returns the
// evaluation. The method does not mutate the inputs.
func (p *Polynomial[FR]) EvalUnivariate(P Univariate[FR], at *emulated.Element[FR]) *emulated.Element[FR] {
106 changes: 104 additions & 2 deletions std/polynomial/polynomial.go
Original file line number Diff line number Diff line change
@@ -4,13 +4,78 @@ import (
"math/bits"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark-crypto/utils"
)

type Polynomial []frontend.Variable
type MultiLin []frontend.Variable

var minFoldScaledLogSize = 16

func FromSlice(s []frontend.Variable) []*frontend.Variable {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
r := make([]*frontend.Variable, len(s))
for i := range s {
r[i] = &s[i]
}
return r
}

// FromSliceReferences maps slice of emulated element references to their values.
func FromSliceReferences(in []*frontend.Variable) []frontend.Variable {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
r := make([]frontend.Variable, len(in))
for i := range in {
r[i] = *in[i]
}
return r
}

func _clone(m MultiLin, p *Pool) MultiLin {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
if p == nil {
return m.Clone()
} else {
return p.Clone(m)
}
}

func _dump(m MultiLin, p *Pool) {
if p != nil {
p.Dump(m)
}
}

// Evaluate assumes len(m) = 1 << len(at)
// it doesn't modify m
func (m MultiLin) EvaluatePool(api frontend.API, at []frontend.Variable, pool *Pool) frontend.Variable {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
_m := _clone(m, pool)

/*minFoldScaledLogSize := 16
if api is r1cs {
minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs
}*/

scaleCorrectionFactor := frontend.Variable(1)
// at each iteration fold by at[i]
for len(_m) > 1 {
if len(_m) >= minFoldScaledLogSize {
scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0]))
} else {
_m.Fold(api, at[0])
}
_m = _m[:len(_m)/2]
at = at[1:]
}

if len(at) != 0 {
panic("incompatible evaluation vector size")
}

result := _m[0]

_dump(_m, pool)

return api.Mul(result, scaleCorrectionFactor)
}

// Evaluate assumes len(m) = 1 << len(at)
// it doesn't modify m
func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable {
@@ -27,7 +92,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va
if len(_m) >= minFoldScaledLogSize {
scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0]))
} else {
_m.fold(api, at[0])
_m.Fold(api, at[0])
}
_m = _m[:len(_m)/2]
at = at[1:]
@@ -42,7 +107,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va

// fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size
// WARNING: The user should halve m themselves after the call
func (m MultiLin) fold(api frontend.API, at frontend.Variable) {
func (m MultiLin) Fold(api frontend.API, at frontend.Variable) {
zero := m[:len(m)/2]
one := m[len(m)/2:]
for j := range zero {
@@ -51,6 +116,43 @@ func (m MultiLin) fold(api frontend.API, at frontend.Variable) {
}
}

func (m *MultiLin) FoldParallel(api frontend.API, r frontend.Variable) utils.Task {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not need it - circuit compilation happen sequentially anyway and you cannot parallelize it.

mid := len(*m) / 2
bottom, top := (*m)[:mid], (*m)[mid:]

*m = bottom

return func(start, end int) {
var t frontend.Variable // no need to update the top part
for i := start; i < end; i++ {
// table[i] ← table[i] + r (table[i + mid] - table[i])
t = api.Sub(&top[i], &bottom[i])
t = api.Mul(&t, &r)
bottom[i] = api.Add(&bottom[i], &t)
}
}
}

// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0]
func (m *MultiLin) Eq(api frontend.API, q []frontend.Variable) {
n := len(q)

if len(*m) != 1<<n {
panic("destination must have size 2 raised to the size of source")
}

//At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁)
for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁
// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
for j := 0; j < (1 << i); j++ {
j0 := j << (n - i) // bᵢ₊₁ = 0
j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1
(*m)[j1] = api.Mul((*m)[j1], q[i]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
(*m)[j0] = api.Sub((*m)[j0], (*m)[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
}
}
}

// foldScaled(m, at) = fold(m, at) / (1 - at)
// it returns 1 - at, for convenience
func (m MultiLin) foldScaled(api frontend.API, at frontend.Variable) (denom frontend.Variable) {
2 changes: 1 addition & 1 deletion std/polynomial/polynomial_test.go
Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ func (c *foldMultiLinCircuit) Define(api frontend.API) error {
return errors.New("folding size mismatch")
}
m := MultiLin(c.M)
m.fold(api, c.At)
m.Fold(api, c.At)
for i := range c.Result {
api.AssertIsEqual(m[i], c.Result[i])
}
Loading