diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 387e3b0..4c1a8b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,64 @@ jobs: - name: Format run: cargo +nightly fmt --all -- --check + build: + name: Build + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - nightly + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust (${{ matrix.rust }}) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + - name: Add SSH key for private repos + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + + - name: Build + run: cargo build + + build-no-std: + name: Build (no-std) + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - nightly + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust (${{ matrix.rust }}) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + - name: Add SSH key for private repos + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + + - name: Build + run: cargo build --no-default-features + test: name: Test runs-on: ubuntu-latest @@ -44,15 +102,13 @@ jobs: toolchain: ${{ matrix.rust }} override: true - - uses: actions/cache@v3 + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + - name: Add SSH key for private repos + uses: webfactory/ssh-agent@v0.9.0 with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} - name: Test - uses: actions-rs/cargo@v1 - with: - command: test + run: cargo test diff --git a/Cargo.lock b/Cargo.lock index 368aba0..5fd48f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" dependencies = [ "cfg-if", "once_cell", @@ -44,7 +44,7 @@ checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "ark-bn254" version = "0.4.0" -source = "git+https://github.com/arkworks-rs/algebra/#228787b5ab87139dc2a79359d2f6b25237f46dac" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ec", "ark-ff", @@ -63,9 +63,9 @@ dependencies = [ "ark-serialize", "ark-snark", "ark-std", - "blake2 0.10.6", + "blake2", "derivative", - "digest 0.10.7", + "digest", "sha2", ] @@ -76,13 +76,13 @@ source = "git+https://github.com/HungryCatsStudio/crypto-primitives?branch=absor dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.50", ] [[package]] name = "ark-ec" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ff", "ark-poly", @@ -90,16 +90,17 @@ dependencies = [ "ark-std", "derivative", "hashbrown", - "itertools 0.12.0", + "itertools 0.12.1", "num-bigint", "num-traits", + "rayon", "zeroize", ] [[package]] name = "ark-ff" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ff-asm", "ark-ff-macros", @@ -107,52 +108,40 @@ dependencies = [ "ark-std", "arrayvec", "derivative", - "digest 0.10.7", - "itertools 0.12.0", + "digest", + "itertools 0.12.1", "num-bigint", "num-traits", "paste", + "rayon", "zeroize", ] [[package]] name = "ark-ff-asm" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "quote", - "syn 2.0.48", + "syn 2.0.50", ] [[package]] name = "ark-ff-macros" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "num-bigint", "num-traits", "proc-macro2", "quote", - "syn 2.0.48", -] - -[[package]] -name = "ark-linear-sumcheck" -version = "0.4.0" -source = "git+https://github.com/arkworks-rs/sumcheck/#956fdaa2b80ff72cda2eafefda3f62a57589ddbd" -dependencies = [ - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "blake2 0.9.2", - "hashbrown", + "syn 2.0.50", ] [[package]] name = "ark-pcs-bench-templates" version = "0.4.0" -source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=brakedown-com-absorb#95fc96c5af94ad3cb6b83a5133810a2954336bc7" +source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#dfdd8e87d3df9059816dd7cec16ade0f4ac0623a" dependencies = [ "ark-crypto-primitives", "ark-ec", @@ -169,29 +158,33 @@ dependencies = [ [[package]] name = "ark-poly" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ff", "ark-serialize", "ark-std", "derivative", "hashbrown", + "rayon", ] [[package]] name = "ark-poly-commit" version = "0.4.0" -source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=brakedown-com-absorb#95fc96c5af94ad3cb6b83a5133810a2954336bc7" +source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#dfdd8e87d3df9059816dd7cec16ade0f4ac0623a" dependencies = [ "ark-crypto-primitives", "ark-ec", "ark-ff", "ark-poly", + "ark-relations", "ark-serialize", "ark-std", "derivative", - "digest 0.10.7", + "digest", + "merlin", "num-traits", + "rayon", ] [[package]] @@ -203,27 +196,28 @@ dependencies = [ "ark-ff", "ark-std", "tracing", + "tracing-subscriber", ] [[package]] name = "ark-serialize" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-serialize-derive", "ark-std", - "digest 0.10.7", + "digest", "num-bigint", ] [[package]] name = "ark-serialize-derive" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.50", ] [[package]] @@ -246,6 +240,22 @@ checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" dependencies = [ "num-traits", "rand", + "rayon", +] + +[[package]] +name = "ark-sumcheck" +version = "0.4.0" +source = "git+ssh://git@github.com/HungryCatsStudio/sumcheck-private.git#721fb56acd6aba79333d8862a0067b01023bc845" +dependencies = [ + "ark-crypto-primitives", + "ark-ff", + "ark-poly", + "ark-poly-commit", + "ark-serialize", + "ark-std", + "hashbrown", + "rayon", ] [[package]] @@ -260,24 +270,13 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "blake2" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" -dependencies = [ - "crypto-mac", - "digest 0.9.0", - "opaque-debug", -] - [[package]] name = "blake2" version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest 0.10.7", + "digest", ] [[package]] @@ -289,6 +288,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cast" version = "0.3.0" @@ -330,18 +335,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.0" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80c21025abd42669a92efc996ef13cfb2c5c627858421ea58d5c3b331a6c134f" +checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.0" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458bf1f341769dfcf849846f65dffdf9146daa56bcd2a47cb4e1de9915567c99" +checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb" dependencies = [ "anstyle", "clap_lex", @@ -396,6 +401,31 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + [[package]] name = "crunchy" version = "0.2.2" @@ -412,16 +442,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "crypto-mac" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" -dependencies = [ - "generic-array", - "subtle", -] - [[package]] name = "derivative" version = "2.2.0" @@ -433,15 +453,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "digest" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" -dependencies = [ - "generic-array", -] - [[package]] name = "digest" version = "0.10.7" @@ -455,9 +466,9 @@ dependencies = [ [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "generic-array" @@ -517,9 +528,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ "either", ] @@ -530,11 +541,20 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libm" @@ -548,6 +568,18 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +[[package]] +name = "merlin" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d" +dependencies = [ + "byteorder", + "keccak", + "rand_core", + "zeroize", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -561,19 +593,18 @@ dependencies = [ [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -591,12 +622,6 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "paste" version = "1.0.14" @@ -659,6 +684,26 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +[[package]] +name = "rayon" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "regex" version = "1.10.3" @@ -690,9 +735,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -705,29 +750,29 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.195" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.50", ] [[package]] name = "serde_json" -version = "1.0.111" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -742,7 +787,7 @@ checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", - "digest 0.10.7", + "digest", ] [[package]] @@ -764,9 +809,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" dependencies = [ "proc-macro2", "quote", @@ -798,6 +843,19 @@ name = "tracing-core" version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] [[package]] name = "typenum" @@ -811,6 +869,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "verifiaml" version = "0.1.0" @@ -819,13 +883,14 @@ dependencies = [ "ark-crypto-primitives", "ark-ec", "ark-ff", - "ark-linear-sumcheck", "ark-pcs-bench-templates", "ark-poly", "ark-poly-commit", "ark-serialize", "ark-std", - "blake2 0.10.6", + "ark-sumcheck", + "blake2", + "rayon", "serde_json", ] @@ -887,9 +952,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -902,45 +967,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" [[package]] name = "zerocopy" @@ -959,7 +1024,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.50", ] [[package]] @@ -979,5 +1044,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.50", ] diff --git a/Cargo.toml b/Cargo.toml index 35fd49c..cd78ce4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,19 +11,25 @@ ark-serialize = { version = "^0.4.0", default-features = false, features = [ "de ark-poly = {version = "^0.4.0", default-features = false } ark-poly-commit = {version = "^0.4.0", default-features = false } ark-crypto-primitives = {version = "^0.4.0", default-features = false } -ark-linear-sumcheck = { git = "https://github.com/arkworks-rs/sumcheck/", default-features = false } +ark-sumcheck = { git = "ssh://git@github.com/HungryCatsStudio/sumcheck-private.git", default-features = false } +rayon = { version = "1.5", default-features = false, optional = true } [dev-dependencies] ark-bn254 = { version = "^0.4.0", default-features = false, features = [ "curve" ] } blake2 = { version = "0.10", default-features = false } serde_json = "1.0.108" -ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "brakedown-com-absorb" } +ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb", default-features = false } [patch.crates-io] ark-ff = { git = "https://github.com/arkworks-rs/algebra/" } ark-ec = { git = "https://github.com/arkworks-rs/algebra/" } ark-serialize = { git = "https://github.com/arkworks-rs/algebra/" } ark-poly = { git = "https://github.com/arkworks-rs/algebra/" } -ark-poly-commit = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "brakedown-com-absorb" } +ark-poly-commit = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb" } ark-crypto-primitives = { git = "https://github.com/HungryCatsStudio/crypto-primitives", branch = "absorb"} ark-bn254 = { git = "https://github.com/arkworks-rs/algebra/" } + +[features] +default = [ "std", "parallel" ] +std = [ "ark-ff/std", "ark-ec/std", "ark-poly/std", "ark-serialize/std", "ark-crypto-primitives/std", "ark-poly-commit/std", "ark-sumcheck/std" ] +parallel = [ "std", "ark-ff/parallel", "ark-ec/parallel", "ark-poly/parallel", "ark-std/parallel", "ark-poly-commit/parallel", "ark-sumcheck/parallel", "rayon" ] \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 4d537be..1d4523f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,6 @@ pub(crate) mod model; pub(crate) mod quantization; - -#[cfg(test)] -pub(crate) mod pcs_types; +pub(crate) mod utils; trait Commitment {} diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 15c9f2b..5176c18 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -1,9 +1,7 @@ use crate::{ model::{ - nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, - qarray::QArray, - Model, Poly, - }, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + isolated_verification::verify_inference, nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::{QArray, QTypeArray}, Model, Poly + }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Ligero, test_sponge::test_sponge} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -14,11 +12,12 @@ use ark_ff::PrimeField; mod input; mod parameters; +use ark_std::test_rng; use input::*; use parameters::*; const INPUT_DIMS: &[usize] = &[28, 28]; -const OUTPUT_DIMS: &[usize] = &[10]; +const OUTPUT_DIMS: usize = 10; // TODO this is incorrect now that we have switched to logs fn build_simple_perceptron_mnist() -> Model @@ -35,12 +34,12 @@ where let bmm: BMMNode = BMMNode::new( WEIGHTS.to_vec(), BIAS.to_vec(), - (flat_dim, OUTPUT_DIMS[0]), + (flat_dim, OUTPUT_DIMS), Z_I, ); let req_bmm: RequantiseBMMNode = RequantiseBMMNode::new( - OUTPUT_DIMS[0], + OUTPUT_DIMS, S_I, Z_I, S_W, @@ -57,13 +56,39 @@ where } #[test] -fn run_simple_perceptron_mnist() { +fn run_native_simple_perceptron_mnist() { /**** Change here ****/ let input = NORMALISED_INPUT_TEST_150; let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; /**********************/ - let perceptron = build_simple_perceptron_mnist::, Brakedown>(); + let perceptron = build_simple_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let output_i8 = perceptron.evaluate(input_i8); + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values(), expected_output); +} + + +#[test] +fn run_padded_simple_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; + /**********************/ + + let perceptron = build_simple_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() @@ -80,3 +105,107 @@ fn run_simple_perceptron_mnist() { println!("Output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values(), expected_output); } + +#[test] +fn prove_inference_simple_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; + /**********************/ + + let perceptron = build_simple_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, _) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + //let (hidden_nodes, com_states) = perceptron.commit(&ck, None).iter().unzip(); + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS], expected_output); +} + + +#[test] +fn verify_inference_simple_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; + /**********************/ + + let perceptron = build_simple_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let mut sponge: PoseidonSponge = test_sponge(); + + assert!(verify_inference( + &vk, + &mut sponge, + &perceptron, + &node_coms, + inference_proof + )); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS], expected_output); +} diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index ba0ad9d..c95326d 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -1,9 +1,7 @@ use crate::{ model::{ - nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, - qarray::QArray, - Model, Poly, - }, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + isolated_verification::verify_inference, nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::{QArray, QTypeArray}, Model, Poly + }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Ligero, test_sponge::test_sponge} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -14,6 +12,7 @@ use ark_ff::PrimeField; mod input; mod parameters; +use ark_std::test_rng; use input::*; use parameters::*; @@ -80,13 +79,38 @@ where } #[test] -fn run_two_layer_perceptron_mnist() { +fn run_native_two_layer_perceptron_mnist() { /**** Change here ****/ let input = NORMALISED_INPUT_TEST_150; let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; /**********************/ - let perceptron = build_two_layer_perceptron_mnist::, Brakedown>(); + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let output_i8 = perceptron.evaluate(input_i8); + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values(), expected_output); +} + +#[test] +fn run_padded_two_layer_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; + /**********************/ + + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() @@ -103,3 +127,105 @@ fn run_two_layer_perceptron_mnist() { println!("Output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values(), expected_output); } + +#[test] +fn prove_inference_two_layer_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; + /**********************/ + + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, _) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIM], expected_output); +} + +#[test] +fn verify_inference_two_layer_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; + /**********************/ + + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let mut sponge: PoseidonSponge = test_sponge(); + + assert!(verify_inference( + &vk, + &mut sponge, + &perceptron, + &node_coms, + inference_proof + )); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIM], expected_output); +} \ No newline at end of file diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs new file mode 100644 index 0000000..1d2f642 --- /dev/null +++ b/src/model/isolated_verification.rs @@ -0,0 +1,299 @@ +use std::vec; + +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; +use ark_ff::PrimeField; +use ark_poly::{DenseMultilinearExtension, Polynomial}; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; +use ark_std::log2; +use ark_sumcheck::ml_sumcheck::{ + protocol::{verifier::SubClaim, PolynomialInfo}, + MLSumcheck, +}; + +use crate::model::nodes::bmm::{BMMNodeCommitment, BMMNodeProof}; + +use super::{ + nodes::{Node, NodeCommitment, NodeProof}, + qarray::{QArray, QTypeArray}, + InferenceProof, Model, Poly, +}; + +fn verify_bmm_node( + vk: &PCS::VerifierKey, + sponge: &mut S, + node_com: &NodeCommitment, + input_com: &LabeledCommitment, + output_com: &LabeledCommitment, + proof: NodeProof, + padded_dims_log: (usize, usize), + input_zero_point: F, // This argument will not be here in the final code +) -> bool +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + let NodeCommitment::BMM(BMMNodeCommitment { + weight_com, + bias_com, + }) = node_com + else { + panic!("Expected BMMNodeCommitment") + }; + + let BMMNodeProof { + sumcheck_proof, + input_opening_proof, + input_opening_value, + weight_opening_proof, + weight_opening_value, + output_bias_opening_proof, + output_opening_value, + bias_opening_value, + } = match proof { + NodeProof::BMM(p) => p, + _ => panic!("Expected BMMNodeProof"), + }; + + // Squeezing random challenge r to bind the first variables of W^ to + let r: Vec = sponge.squeeze_field_elements(padded_dims_log.1); + + // The hypercube sum proved in sumcheck should be the difference between + // the output and the bias + let sumcheck_evaluation = output_opening_value - bias_opening_value; + + // Public information about the sumchecked polynomial + // g(x) = (input - zero_point)^(x) * W^(r, x), + let info = PolynomialInfo { + max_multiplicands: 2, + num_variables: padded_dims_log.0, + products: vec![(F::one(), vec![0, 1])], + }; + + // Verify the sumcheck proof for g and obtaining the oracle-call point s + // and claimed evaluation g(s) + let Ok(subclaim) = MLSumcheck::verify(&info, sumcheck_evaluation, &sumcheck_proof, sponge) + else { + return false; + }; + + let SubClaim { + point: oracle_point, + expected_evaluation: oracle_evaluation, + } = subclaim; + + // Verify g(s) agrees with the claims for (input - zero_point)^(s) and + // W^(r, s) + if oracle_evaluation != (input_opening_value - input_zero_point) * weight_opening_value { + return false; + } + + // Verify that the opening of input^ at s agrees with the claimed value for + // (input - zero_point)^(s) + // TODO possibly rng, not None + if !PCS::check( + vk, + [input_com], + &oracle_point, + [input_opening_value], + &input_opening_proof, + sponge, + None, + ) + .unwrap() + { + return false; + } + + // Verify the openings of W^ at r || s and b and o at r match the claimed + // values + // TODO possibly rng, not None + if !PCS::check( + vk, + [weight_com], + &r.clone().into_iter().chain(oracle_point).collect(), + [weight_opening_value], + &weight_opening_proof, + sponge, + None, + ) + .unwrap() + { + return false; + } + + PCS::check( + vk, + [output_com, bias_com], + &r, + [output_opening_value, bias_opening_value], + &output_bias_opening_proof, + sponge, + None, + ) + .unwrap() +} + +fn verify_node( + vk: &PCS::VerifierKey, + sponge: &mut S, + node_com: &NodeCommitment, + input_com: &LabeledCommitment, + output_com: &LabeledCommitment, + proof: NodeProof, + padded_dims_log: Option<(usize, usize)>, + input_zero_point: Option, +) -> bool +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + match node_com { + NodeCommitment::BMM(_) => verify_bmm_node( + vk, + sponge, + node_com, + input_com, + output_com, + proof, + padded_dims_log.unwrap(), + input_zero_point.unwrap(), + ), + _ => true, + } +} + +pub(crate) fn verify_inference( + vk: &PCS::VerifierKey, + sponge: &mut S, + model: &Model, + node_commitments: &Vec>, + inference_proof: InferenceProof, +) -> bool +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + let InferenceProof { + inputs_outputs, + node_value_commitments, + node_proofs, + opening_proofs, + } = inference_proof; + + // Absorb all commitments into the sponge + sponge.absorb(&node_value_commitments); + + // TODO Verify that all commited NIOs live in the right range (to be + // discussed) + + // Verify node proofs + for (((node, node_com), io_com), node_proof) in model + .nodes + .iter() + .zip(node_commitments.iter()) + .zip(node_value_commitments.windows(2)) + .zip(node_proofs.into_iter()) + { + // This will not be necessary in the actual code, as the BMM dimensions + // and zero point will be contained in the (possibly hidden) BMMNode + // and therefore won't be passed to the proving method + let (padded_dims_log, input_zero_point) = match node { + Node::BMM(bmm) => ( + Some(bmm.padded_dims_log()), + Some(F::from(bmm.input_zero_point())), + ), + _ => (None, None), + }; + + if !verify_node( + vk, + sponge, + node_com, + &io_com[0], + &io_com[1], + node_proof, + padded_dims_log, + input_zero_point, + ) { + return false; + } + } + + // Verifying model IO + // TODO maybe this can be made more efficient by not committing to the + // output nodes and instead working witht their plain values all along, + // but that would require messy node-by-node handling + let input_node_com = node_value_commitments.first().unwrap(); + let input_node_qarray = match &inputs_outputs[0] { + QTypeArray::S(i) => i, + _ => panic!("Model input should be QTypeArray::S"), + }; + let input_node_f: Vec = input_node_qarray + .values() + .iter() + .map(|x| F::from(*x)) + .collect(); + + let output_node_com = node_value_commitments.last().unwrap(); + // TODO maybe it's better to save this as F in the proof? + let output_node_f: Vec = match &inputs_outputs[1] { + QTypeArray::S(o) => o.values().iter().map(|x| F::from(*x)).collect(), + _ => panic!("Model output should be QTypeArray::S"), + }; + + // Absorb the model IO output and squeeze the challenge point + // Absorb the plain output and squeeze the challenge point + sponge.absorb(&input_node_f); + sponge.absorb(&output_node_f); + let input_challenge_point = sponge.squeeze_field_elements(log2(input_node_f.len()) as usize); + let output_challenge_point = sponge.squeeze_field_elements(log2(output_node_f.len()) as usize); + + // Verifying that the actual input was honestly padded with zeros + let padded_input_shape = input_node_qarray.shape().clone(); + let honestly_padded_input = input_node_qarray + .compact_resize(model.input_shape().clone(), 0) + .compact_resize(padded_input_shape, 0); + + if honestly_padded_input.values() != input_node_qarray.values() { + return false; + } + + // The verifier must evaluate the MLE given by the plain input values + let input_node_eval = + Poly::from_evaluations_vec(log2(input_node_f.len()) as usize, input_node_f) + .evaluate(&input_challenge_point); + let output_node_eval = + Poly::from_evaluations_vec(log2(output_node_f.len()) as usize, output_node_f) + .evaluate(&output_challenge_point); + + // The computed values should match the openings of the corresponding + // vectors + // TODO rng, None + if !PCS::check( + vk, + [input_node_com], + &input_challenge_point, + [input_node_eval], + &opening_proofs[0], + sponge, + None, + ) + .unwrap() + { + return false; + } + + PCS::check( + vk, + [output_node_com], + &output_challenge_point, + [output_node_eval], + &opening_proofs[1], + sponge, + None, + ) + .unwrap() +} diff --git a/src/model/mod.rs b/src/model/mod.rs index 4b3a50f..d2dec26 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -2,8 +2,8 @@ use ark_std::{log2, rand::RngCore}; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly::DenseMultilinearExtension; -use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use crate::model::nodes::{NodeOps, NodeOpsSNARK}; use crate::{model::nodes::Node, quantization::QSmallType}; @@ -15,26 +15,30 @@ use self::{ }; mod examples; +mod isolated_verification; mod nodes; mod qarray; -mod reshaping; pub(crate) type Poly = DenseMultilinearExtension; +pub(crate) type LabeledPoly = LabeledPolynomial>; pub(crate) struct InferenceProof where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - // Model output tensors - outputs: Vec, + // Model input and output tensors in plain + pub(crate) inputs_outputs: Vec, + + // Commitments to each of the node values + pub(crate) node_value_commitments: Vec>, // Proofs of evaluation of each of the model's nodes - node_proofs: Vec, + pub(crate) node_proofs: Vec>, // Proofs of opening of each of the model's outputs - opening_proofs: Vec, + pub(crate) opening_proofs: Vec, } // TODO change the functions that receive vectors to receive slices instead whenever it makes sense @@ -70,6 +74,10 @@ where } } + pub(crate) fn input_shape(&self) -> &Vec { + &self.input_shape + } + pub(crate) fn setup_keys( &self, rng: &mut R, @@ -128,7 +136,8 @@ where ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, sponge: &mut S, - node_commitments: Vec>, + node_coms: &Vec>, + node_com_states: &Vec>, input: QArray, ) -> InferenceProof { // TODO Absorb public parameters into s (to be determined what exactly) @@ -141,7 +150,7 @@ where 0, ); - let output_f = output.values().iter().map(|x| F::from(*x)).collect(); + let output_f: Vec = output.values().iter().map(|x| F::from(*x)).collect(); let mut output = QTypeArray::S(output); @@ -149,7 +158,10 @@ where // TODO handling F and QSmallType is inelegant; we might want to switch // to F for IO in NodeOps::prove let mut node_outputs = vec![output.clone()]; - let mut node_outputs_f = vec![output_f]; + let mut node_output_mles = vec![Poly::from_evaluations_vec( + log2(output_f.len()) as usize, + output_f, + )]; for node in &self.nodes { output = node.padded_evaluate(&output); @@ -160,27 +172,31 @@ where }; node_outputs.push(output.clone()); - node_outputs_f.push(output_f); + node_output_mles.push(Poly::from_evaluations_vec( + log2(output_f.len()) as usize, + output_f, + )); } // Committing to node outputs as MLEs (individual per node for now) - let output_mles: Vec>> = node_outputs_f + let labeled_output_mles: Vec>> = node_output_mles .iter() - .map(|values| + .map(|mle| // TODO change dummy label once we e.g. have given numbers to the // nodes in the model: fc_1, fc_2, relu_1, etc. + // TODO maybe we don't need to clone, if `prove` can take a reference LabeledPolynomial::new( "dummy".to_string(), - Poly::from_evaluations_vec(log2(values.len()) as usize, values.clone()), + mle.clone(), None, None, )) .collect(); - let (node_coms, node_com_states) = PCS::commit(ck, &output_mles, rng).unwrap(); + let (output_coms, output_com_states) = PCS::commit(ck, &labeled_output_mles, rng).unwrap(); // Absorb all commitments into the sponge - sponge.absorb(&node_coms); + sponge.absorb(&output_coms); // TODO Prove that all commited NIOs live in the right range (to be // discussed) @@ -188,22 +204,26 @@ where let mut node_proofs = Vec::new(); // Second pass: proving - for ((((node, node_com), values), l_v_coms), v_coms_states) in self + for (((((node, node_com), node_com_state), values), l_v_coms), v_coms_states) in self .nodes .iter() - .zip(node_commitments.iter()) - .zip(node_outputs.windows(2)) - .zip(node_coms.windows(2)) - .zip(node_com_states.windows(2)) + .zip(node_coms.iter()) + .zip(node_com_states.iter()) + .zip(labeled_output_mles.windows(2)) + .zip(output_coms.windows(2)) + .zip(output_com_states.windows(2)) { - // TODO prove likely needs to receive the sponge for randomness/FS node_proofs.push(node.prove( + ck, sponge, - node_com, - values[0].clone(), - l_v_coms[0].commitment(), - values[1].clone(), - l_v_coms[1].commitment(), + &node_com, + &node_com_state, + &values[0], + &l_v_coms[0], + &v_coms_states[0], + &values[1], + &l_v_coms[1], + &v_coms_states[1], )); } @@ -212,21 +232,21 @@ where // output nodes and instead working witht their plain values all along, // but that would require messy node-by-node handling let input_node = node_outputs.first().unwrap(); - let input_node_f = node_outputs_f.first().unwrap(); - let input_labeled_value = output_mles.first().unwrap(); - let input_node_com = node_coms.first().unwrap(); - let input_node_com_state = node_com_states.first().unwrap(); + let input_node_f = node_output_mles.first().unwrap().to_evaluations(); + let input_labeled_value = labeled_output_mles.first().unwrap(); + let input_node_com = output_coms.first().unwrap(); + let input_node_com_state = output_com_states.first().unwrap(); let output_node = node_outputs.last().unwrap(); - let output_node_f = node_outputs_f.last().unwrap(); - let output_labeled_value = output_mles.last().unwrap(); - let output_node_com = node_coms.last().unwrap(); - let output_node_com_state = node_com_states.last().unwrap(); + let output_node_f = node_output_mles.last().unwrap().to_evaluations(); + let output_labeled_value = labeled_output_mles.last().unwrap(); + let output_node_com = output_coms.last().unwrap(); + let output_node_com_state = output_com_states.last().unwrap(); // Absorb the model IO output and squeeze the challenge point // Absorb the plain output and squeeze the challenge point - sponge.absorb(input_node_f); - sponge.absorb(output_node_f); + sponge.absorb(&input_node_f); + sponge.absorb(&output_node_f); let input_challenge_point = sponge.squeeze_field_elements(log2(input_node_f.len()) as usize); let output_challenge_point = @@ -260,10 +280,10 @@ where ) .unwrap(); - /* TODO (important) Change output_node to all boundary nodes: first and last */ // TODO prove that inputs match input commitments? InferenceProof { - outputs: vec![input_node.clone(), output_node.clone()], + inputs_outputs: vec![input_node.clone(), output_node.clone()], + node_value_commitments: output_coms, node_proofs, opening_proofs: vec![input_opening_proof, output_opening_proof], } diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index b798ecc..cd7bcc2 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -1,13 +1,18 @@ +use ark_std::rc::Rc; + +use ark_poly::{MultilinearExtension, Polynomial}; use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_std::rand::RngCore; +use ark_sumcheck::ml_sumcheck::protocol::ListOfProductsOfPolynomials; +use ark_sumcheck::ml_sumcheck::{MLSumcheck, Proof}; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::{BMMQInfo, QInfo, QLargeType, QScaleType, QSmallType}; use crate::{Commitment, CommitmentState}; @@ -35,14 +40,16 @@ pub(crate) struct BMMNode { phantom: PhantomData<(F, S, PCS)>, } +/// Commitment to a BMM node, consisting of a commitment to the *dual* of the +/// weight MLE and one to the *dual* of the bias MLE pub(crate) struct BMMNodeCommitment where F: PrimeField, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - weight_com: PCS::Commitment, - bias_com: PCS::Commitment, + pub(crate) weight_com: LabeledCommitment, + pub(crate) bias_com: LabeledCommitment, } impl Commitment for BMMNodeCommitment @@ -53,6 +60,8 @@ where { } +/// Commitment states associated to a BMMNodeCommitment: one for the weight and +/// one for the bias pub(crate) struct BMMNodeCommitmentState where F: PrimeField, @@ -71,8 +80,36 @@ where { } -pub(crate) struct BMMNodeProof { - // this will be the sumcheck proof +/// Proof of execution of a BMM node, consisting of a sumcheck proof and four +/// PCS opening proofs +pub(crate) struct BMMNodeProof< + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +> { + /// Sumcheck protocol proof for the polynomial + /// g(x) = (input - zero_point)^(x) * W^(r, x), + /// where v^ denotes the dual of the MLE of v and r is a challenge point + pub(crate) sumcheck_proof: Proof, + + /// Value of the *dual* of the input MLE at the challenge point s and proof + /// of opening + pub(crate) input_opening_proof: PCS::Proof, + pub(crate) input_opening_value: F, + + /// Value of the *dual* of the weight MLE at the challenge point r || s and proof of + /// opening + pub(crate) weight_opening_proof: PCS::Proof, + pub(crate) weight_opening_value: F, + + /// Proof of opening of the *duals* of the output and bias MLEs at the + // challenge point + pub(crate) output_bias_opening_proof: PCS::Proof, + + /// Value of the *dual* of the weight MLE at the challenge point and proof of + /// opening + pub(crate) output_opening_value: F, + pub(crate) bias_opening_value: F, } impl NodeOps for BMMNode @@ -133,7 +170,7 @@ where impl NodeOpsSNARK for BMMNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -203,14 +240,14 @@ where rng: Option<&mut dyn RngCore>, ) -> (NodeCommitment, NodeCommitmentState) { // TODO should we separate the associated commitment type into one with state and one without? - - let num_vars_weights = self.padded_dims_log.0 + self.padded_dims_log.1; let padded_weights_f: Vec = self.padded_weights.iter().map(|w| F::from(*w)).collect(); + // TODO part of this code is duplicated in prove, another hint that this should probs + // be stored let weight_poly = LabeledPolynomial::new( "weight_poly".to_string(), - Poly::from_evaluations_vec(num_vars_weights, padded_weights_f), - None, + Poly::from_evaluations_vec(self.com_num_vars(), padded_weights_f), + Some(1), None, ); @@ -219,7 +256,7 @@ where let bias_poly = LabeledPolynomial::new( "bias_poly".to_string(), Poly::from_evaluations_vec(self.padded_dims_log.1, padded_bias_f), - None, + Some(1), None, ); @@ -227,8 +264,8 @@ where ( NodeCommitment::BMM(BMMNodeCommitment { - weight_com: coms.0[0].commitment().clone(), - bias_com: coms.0[1].commitment().clone(), + weight_com: coms.0[0].clone(), + bias_com: coms.0[1].clone(), }), NodeCommitmentState::BMM(BMMNodeCommitmentState { weight_com_state: coms.1[0].clone(), @@ -239,14 +276,153 @@ where fn prove( &self, - s: &mut S, + ck: &PCS::CommitterKey, + sponge: &mut S, node_com: &NodeCommitment, - input: QTypeArray, - input_com: &PCS::Commitment, - output: QTypeArray, - output_com: &PCS::Commitment, - ) -> NodeProof { - unimplemented!() + node_com_state: &NodeCommitmentState, + input: &LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: &PCS::CommitmentState, + ) -> NodeProof { + let (weight_com, bias_com) = match node_com { + NodeCommitment::BMM(BMMNodeCommitment { + weight_com, + bias_com, + }) => (weight_com, bias_com), + _ => panic!("BMMNode::prove expected node commitment of type BMMNodeCommitment"), + }; + + let (weight_com_state, bias_com_state) = match node_com_state { + NodeCommitmentState::BMM(BMMNodeCommitmentState { + weight_com_state, + bias_com_state, + }) => (weight_com_state, bias_com_state), + _ => panic!( + "BMMNode::prove expected node commitment state of type BMMNodeCommitmentState" + ), + }; + + // We can squeeze directly, since the sponge has already absorbed all the + // commitments in Model::prove_inference + let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); + + let i_z_p_f = F::from(self.input_zero_point); + + /// (f - zero-point)^ + let shifted_input_mle = Poly::from_evaluations_vec( + input.num_vars(), + input.polynomial().iter().map(|x| *x - i_z_p_f).collect(), + ); + + // TODO consider whether this can be done once and stored + let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); + + // Dual of the MLE of the row-major flattening of the weight matrix + let weight_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); + + // TODO consider whether this can be done once and stored + let bias_f = self.padded_bias.iter().map(|w| F::from(*w)).collect(); + // Dual of the MLE of the bias vector + let bias_mle = Poly::from_evaluations_vec(self.padded_dims_log.1, bias_f); + + let bias_opening_value = bias_mle.evaluate(&r); + let output_opening_value = output.evaluate(&r); + + // Constructing the sumcheck polynomial + // g(x) = (input - zero_point)^(x) * W^(r, x), + let bound_weight_mle = weight_mle.fix_variables(&r); + let mut g = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); + + // TODO we are cloning the input here, can we do better? + g.add_product( + vec![shifted_input_mle, bound_weight_mle] + .into_iter() + .map(Rc::new) + .collect::>(), + F::one(), + ); + + let (sumcheck_proof, prover_state) = + MLSumcheck::::prove_as_subprotocol(&g, sponge).unwrap(); + + // The prover computes the claimed evaluations of weight_mle and + // input_mle at the random challenge point + // s := prover_state.randomness, the list of random values sampled by + // the verifier during sumcheck. Note that this is different from r + // above. + // + // We need to reveal g(s) by opening input^ at s and weight^ at s || r; + // and also open output^ and bias^ at r + let claimed_evaluations: Vec = g + .flattened_ml_extensions + .iter() + .map(|x| x.evaluate(&prover_state.randomness)) + .collect(); + + // Recall that the first factor of g was the *shifted* dual input + // (input - zero_point)^ + let input_opening_value = claimed_evaluations[0] + i_z_p_f; + let weight_opening_value = claimed_evaluations[1]; + + let input_opening_proof = PCS::open( + &ck, + [input], + [input_com], + &prover_state.randomness, + sponge, + [input_com_state], + None, + ) + .unwrap(); + + let weight_opening_proof = PCS::open( + &ck, + [&LabeledPolynomial::new( + "weight_mle".to_string(), + weight_mle, + Some(1), + None, + )], + [weight_com], + &r.clone() + .into_iter() + .chain(prover_state.randomness) + .collect(), + sponge, + [weight_com_state], + None, + ) + .unwrap(); + + // TODO: b and o are opened at the same point, so they could be opened + // with a single call to PCS::open + let output_bias_opening_proof = PCS::open( + &ck, + [ + output, + &LabeledPolynomial::new("bias_mle".to_string(), bias_mle, Some(1), None), + ], + [output_com, bias_com], + &r, + sponge, + [output_com_state, bias_com_state], + None, + ) + .unwrap(); + + NodeProof::BMM(BMMNodeProof { + sumcheck_proof, + input_opening_proof, + input_opening_value, + weight_opening_proof, + weight_opening_value, + output_bias_opening_proof, + output_opening_value, + bias_opening_value, + }) } } @@ -304,6 +480,14 @@ where phantom: PhantomData, } } + + pub(crate) fn padded_dims_log(&self) -> (usize, usize) { + self.padded_dims_log + } + + pub(crate) fn input_zero_point(&self) -> QSmallType { + self.input_zero_point + } } // TODO in constructor, add quantisation information checks? (s_bias = s_input * s_weight, z_bias = 0, z_weight = 0, etc.) // TODO in constructor, check bias length matches appropriate matrix dimension diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index 7b9d83f..e5cc3b3 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -1,5 +1,6 @@ +use ark_crypto_primitives::sponge::Absorb; use ark_ff::PrimeField; -use ark_poly_commit::PolynomialCommitment; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::{ @@ -19,7 +20,10 @@ use self::{ reshape::ReshapeNode, }; -use super::qarray::{QArray, QTypeArray}; +use super::{ + qarray::{QArray, QTypeArray}, + LabeledPoly, +}; pub(crate) mod bmm; pub(crate) mod relu; @@ -51,7 +55,7 @@ pub(crate) trait NodeOps { pub(crate) trait NodeOpsSNARK where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -98,13 +102,17 @@ where /// Produce a node output proof fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, - input_com: &PCS::Commitment, - output: QTypeArray, - output_com: &PCS::Commitment, - ) -> NodeProof; + node_com_state: &NodeCommitmentState, + input: &LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: &PCS::CommitmentState, + ) -> NodeProof; } pub(crate) enum Node @@ -119,8 +127,13 @@ where Reshape(ReshapeNode), } -pub(crate) enum NodeProof { - BMM(BMMNodeProof), +pub(crate) enum NodeProof +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + BMM(BMMNodeProof), RequantiseBMM(RequantiseBMMNodeProof), ReLU(()), Reshape(()), @@ -154,7 +167,7 @@ where // elegantly by simply implementing the trait impl Node where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -191,7 +204,7 @@ where // elegantly by simply implementing the trait impl NodeOps for Node where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -213,7 +226,7 @@ where impl NodeOpsSNARK for Node where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -244,14 +257,28 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, - input_com: &PCS::Commitment, - output: QTypeArray, - output_com: &PCS::Commitment, - ) -> NodeProof { - self.as_node_ops_snark() - .prove(s, node_com, input, input_com, output, output_com) + node_com_state: &NodeCommitmentState, + input: &LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: &PCS::CommitmentState, + ) -> NodeProof { + self.as_node_ops_snark().prove( + ck, + s, + node_com, + node_com_state, + input, + input_com, + input_com_state, + output, + output_com, + output_com_state, + ) } } diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 3e67abf..1cfc44c 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -1,16 +1,16 @@ use ark_std::log2; use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::PolynomialCommitment; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::QSmallType; -use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK}; +use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeProof}; // Rectified linear unit node performing x |-> max(0, x). pub(crate) struct ReLUNode @@ -49,7 +49,7 @@ where // impl NodeOpsSnark impl NodeOpsSNARK for ReLUNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -66,7 +66,7 @@ where ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, ) -> (NodeCommitment, NodeCommitmentState) { - todo!() + (NodeCommitment::ReLU(()), NodeCommitmentState::ReLU(())) } // TODO this is the same as evaluate() for now; the two will likely differ @@ -84,14 +84,18 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, - input_com: &PCS::Commitment, - output: QTypeArray, - output_com: &PCS::Commitment, - ) -> super::NodeProof { - todo!() + node_com_state: &NodeCommitmentState, + input: &LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: &PCS::CommitmentState, + ) -> NodeProof { + NodeProof::ReLU(()) } } diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index a0ce211..319b21a 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -1,13 +1,13 @@ use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::{ requantise_fc, BMMQInfo, QInfo, QLargeType, QScaleType, QSmallType, RoundingScheme, }; @@ -87,7 +87,7 @@ where impl NodeOpsSNARK for RequantiseBMMNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -146,14 +146,18 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, - input_com: &PCS::Commitment, - output: QTypeArray, - output_com: &PCS::Commitment, - ) -> NodeProof { - unimplemented!() + node_com_state: &NodeCommitmentState, + input: &LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: &PCS::CommitmentState, + ) -> NodeProof { + NodeProof::RequantiseBMM(RequantiseBMMNodeProof {}) } } diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 6e7d091..04bf966 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -1,13 +1,13 @@ use ark_std::log2; use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::PolynomialCommitment; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, NodeCommitmentState, Poly}; use crate::quantization::QSmallType; use super::{NodeCommitment, NodeOps, NodeOpsSNARK, NodeProof}; @@ -59,7 +59,7 @@ where impl NodeOpsSNARK for ReshapeNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -112,23 +112,27 @@ where &self, ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, - ) -> ( - super::NodeCommitment, - super::NodeCommitmentState, - ) { - todo!() + ) -> (NodeCommitment, NodeCommitmentState) { + ( + NodeCommitment::Reshape(()), + NodeCommitmentState::Reshape(()), + ) } fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, - input_com: &PCS::Commitment, - output: QTypeArray, - output_com: &PCS::Commitment, - ) -> NodeProof { - unimplemented!() + node_com_state: &NodeCommitmentState, + input: &LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: &PCS::CommitmentState, + ) -> NodeProof { + NodeProof::Reshape(()) } } diff --git a/src/model/reshaping.rs b/src/model/reshaping.rs deleted file mode 100644 index beea6a6..0000000 --- a/src/model/reshaping.rs +++ /dev/null @@ -1,40 +0,0 @@ -use ark_std::vec; - -// Let `array` be an array of length m. Define M = 2^(ceil(max(log2(m), 0))) -// This function pads `array` to length M with the value `pad`. -pub(crate) fn pad_pow2_1d(mut array: Vec, pad: T) -> Vec { - let m = array.len().next_power_of_two(); - array.resize(m, pad); - array -} - -// Let `array` be a non-empty array of subarrays. Let m = array.len() and -// n = array[0].len(). Define M = 2^(ceil(max(log2(m), 0))) and -// N = 2^(ceil(max(log2(n), 0))). -// This function pads (with the value `pad`) or truncates each subarray of -// `array` to length N; and also pads `array` itself to length M with -// new subarrays of length N filled with the value `pad`. -// -// Panics if `array` is empty -pub(crate) fn pad_pow2_2d(array: Vec>, pad: T) -> Vec> { - assert!(array.is_empty()); - - let m_0 = array.len(); - let m = m_0.next_power_of_two(); - - let n = array[0].len().next_power_of_two(); - - let mut padded_array = Vec::with_capacity(m); - - for subarray in array { - let mut s = subarray.clone(); - s.resize(n, pad); - padded_array.push(s); - } - - for _ in 0..(m - m_0) { - padded_array.push(vec![pad; n]); - } - - padded_array -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..15a05fb --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,8 @@ +#[cfg(test)] +pub(crate) mod pcs_types; + +#[cfg(test)] +pub(crate) mod test_sponge; + +use ark_ff::Field; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; diff --git a/src/pcs_types.rs b/src/utils/pcs_types.rs similarity index 67% rename from src/pcs_types.rs rename to src/utils/pcs_types.rs index 727e397..82414c9 100644 --- a/src/pcs_types.rs +++ b/src/utils/pcs_types.rs @@ -3,10 +3,14 @@ use ark_crypto_primitives::{ merkle_tree::{ByteDigestConverter, Config}, sponge::poseidon::PoseidonSponge, }; +// no-std note: +// Currently, we use the `LeafIdentityHasher` from ark_pcs_bench_templates. +// This is not ideal, since the entire `ark_pcs_bench_templates` crate does not support `no_std` +// (due to `criterion`) dependency. use ark_pcs_bench_templates::*; use ark_poly::DenseMultilinearExtension; -use ark_poly_commit::linear_codes::{LinearCodePCS, MultilinearBrakedown}; +use ark_poly_commit::linear_codes::{LinearCodePCS, MultilinearLigero}; use blake2::Blake2s256; // Brakedown PCS over BN254 @@ -26,14 +30,9 @@ impl Config for MerkleTreeParams { type MTConfig = MerkleTreeParams; type ColHasher = FieldToBytesColHasher; -pub(crate) type Brakedown = LinearCodePCS< - MultilinearBrakedown< - F, - MTConfig, - PoseidonSponge, - DenseMultilinearExtension, - ColHasher, - >, + +pub(crate) type Ligero = LinearCodePCS< + MultilinearLigero, DenseMultilinearExtension, ColHasher>, F, DenseMultilinearExtension, PoseidonSponge, diff --git a/src/utils/test_sponge.rs b/src/utils/test_sponge.rs new file mode 100644 index 0000000..eab3854 --- /dev/null +++ b/src/utils/test_sponge.rs @@ -0,0 +1,39 @@ +use ark_crypto_primitives::sponge::{ + poseidon::{PoseidonConfig, PoseidonSponge}, + CryptographicSponge, +}; +use ark_ff::PrimeField; +use ark_std::test_rng; + +pub(crate) fn test_sponge() -> PoseidonSponge { + PoseidonSponge::new(&poseidon_parameters_for_test()) +} + +/// Generate default parameters for alpha = 17, state-size = 8 +/// +/// WARNING: This poseidon parameter is not secure. Please generate +/// your own parameters according the field you use. +fn poseidon_parameters_for_test() -> PoseidonConfig { + let full_rounds = 8; + let partial_rounds = 31; + let alpha = 17; + + let mds = vec![ + vec![F::one(), F::zero(), F::one()], + vec![F::one(), F::one(), F::zero()], + vec![F::zero(), F::one(), F::one()], + ]; + + let mut ark = Vec::new(); + let mut ark_rng = test_rng(); + + for _ in 0..(full_rounds + partial_rounds) { + let mut res = Vec::new(); + + for _ in 0..3 { + res.push(F::rand(&mut ark_rng)); + } + ark.push(res); + } + PoseidonConfig::new(full_rounds, partial_rounds, alpha, mds, ark, 2, 1) +}