diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 57c47fb8ed..0fec685f94 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -54,10 +54,57 @@ func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) f } func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { + + if fastTrack := builder.mulAccFastTrack(a, b, c); fastTrack != nil { + return fastTrack + } + // TODO can we do better here to limit allocations? return builder.Add(a, builder.Mul(b, c)) } +// special case for when a/c is constant +// let a = a' * α, b = b' * β, c = c' * α +// then a + b * c = a' * α + (b' * c') (β * α) +// thus qL = a', qR = 0, qM = b'c' +func (builder *builder) mulAccFastTrack(a, b, c frontend.Variable) frontend.Variable { + var ( + aVar, bVar, cVar expr.Term + ok bool + ) + if aVar, ok = a.(expr.Term); !ok { + return nil + } + if bVar, ok = b.(expr.Term); !ok { + return nil + } + if cVar, ok = c.(expr.Term); !ok { + return nil + } + + if aVar.VID == bVar.VID { + bVar, cVar = cVar, bVar + } + + if aVar.VID != cVar.VID { + return nil + } + + res := builder.newInternalVariable() + builder.addPlonkConstraint(sparseR1C{ + xa: aVar.VID, + xb: bVar.VID, + xc: res.VID, + qL: aVar.Coeff, + qR: constraint.Element{}, + qO: builder.tMinusOne, + qM: builder.cs.Mul(bVar.Coeff, cVar.Coeff), + qC: constraint.Element{}, + commitment: 0, + }) + return res +} + // neg returns -in func (builder *builder) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) diff --git a/frontend/cs/scs/duplicate_test.go b/frontend/cs/scs/api_test.go similarity index 85% rename from frontend/cs/scs/duplicate_test.go rename to frontend/cs/scs/api_test.go index cf7ad8f169..2c57fed7fa 100644 --- a/frontend/cs/scs/duplicate_test.go +++ b/frontend/cs/scs/api_test.go @@ -154,3 +154,29 @@ func TestExistDiv0(t *testing.T) { assert.NoError(err) _ = solution } + +type mulAccFastTrackCircuit struct { + A, B frontend.Variable + Res frontend.Variable +} + +func (c *mulAccFastTrackCircuit) Define(api frontend.API) error { + r := api.MulAcc(api.Mul(c.A, 1), c.B, c.A) + api.AssertIsEqual(r, c.Res) + return nil +} + +func TestMulAccFastTrack(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &mulAccFastTrackCircuit{}) + assert.NoError(err) + assert.Equal(2, ccs.GetNbConstraints()) + w, err := frontend.NewWitness(&mulAccFastTrackCircuit{ + A: 11, B: 21, + Res: 242, + }, ecc.BN254.ScalarField()) + assert.NoError(err) + solution, err := ccs.Solve(w) + assert.NoError(err) + _ = solution +}