-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: auto split onnx model (https://github.com/zkonduit/ezkl/discussions/744) #855
base: main
Are you sure you want to change the base?
Conversation
@ExcellentHH thanks so much for getting this over the finish line. Will review this in a bit but first off huge congratulations on seeing this through 🎉🎉🎉🎉🍾🍾🍾🍾🍾 |
Great stuff indeed! I think we're going to want to change to KZG commitments rather than hashes to save rows, and add a final loop to actually compute the proofs, glue, and verify in an integration test. |
Amazing work! One usability caveat that might potentially cause issues are networks with recurrent structures. The following line should deal fine with DAG type of networks
For better usability it might be worth flagging cycles in the network to users, and provide an error message saying the scheme will not support such kinds of networks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great start for this. I've added some comments.
It would also be great to add an integration test for all this that splits a smaller model like nanoGPT into 2^17 chunks, generates the witness for each one, proves all the chunks, then verifies that the proof commitments match using the method from the example notebooks :)
Would also note that we need to ensure that the input_scale for subgraphs with index > 0 is the same as the output scale of the prior subgraph. You can see how we do this in the example notebooks. Rn fwiw it is calibrating over all scales. This will also shorten runtime significantly.
There are some other subtleties to discuss but lets resolve these issues first. If you don't have bandwidth to solve all of these lmk and I can help you get it over the line
help="Input shape for the ONNX model in JSON format. Default is '{\"input\": [1, 3, 224, 224]}'.") | ||
parser.add_argument("--simplify", action='store_true', | ||
help="Flag to indicate if the model should be simplified. Default is False.") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we add an argument to specify the visibility of the stitched commitments eg. public
, hashed
, polycommit
parser.add_argument("--simplify", action='store_true', | ||
help="Flag to indicate if the model should be simplified. Default is False.") | ||
|
||
args = parser.parse_args() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we add a flag to (optionally) compile the circuit (or not) -- but if the upper_bound_per_subgraph
is 23 logrows for eg. it should 1. calibrate 2. reduce the logrows to 23 manually 3. compile
is_pass = False | ||
with open(json_file, 'r') as f: | ||
data = json.load(f) | ||
total_assignments = data.get("total_assignments", 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be checking num_rows
not total_assignements
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Process ONNX model and generate subgraphs.") | ||
parser.add_argument("--onnx_model_path", type=str, default='./resnet18.onnx', help="Path to the ONNX model. Default is './resnet18.onnx'.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we change the default to match the library's default of network.onnx
res = ezkl.gen_settings(temp_model_name, py_run_args=run_args) | ||
assert res == True | ||
|
||
data_path = f"input_data_{subgraph_index}.json" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
something I realized would be super helpful to make this more useable, is to create a separate script that takes the partitioned onnx
files and generates a new set of intermediate / subgraph input json files given an orginal file.
So for eg.
If I have input.json
and subgraph1.onnx
,subgraph2.onnx
, subgraph3.onnx
the script generate_subgraph_input.json
should yield input files for each subgraph:
input_data_1.json
, input_data_2.json
, input_data_3.json
.
Even if input_data_1.json
= input.json
would still be useful to make it explicit.
Just came to this realization cause I ended up having to code this up and would be useful
@alexander-camuto @jasonmorton @JSeam2 Thank you for the response and encouragement. Special thanks to @alexander-camuto for the improvement suggestions—I’ve learned a lot from them, and I apologize for the immaturity of my code. I’ve been a bit busy recently, but I’ll work on the code improvements as soon as possible. Thanks again! |
A simple and naive solution to a discussion topic (#744) I raised.
Script Explanation
This script addresses the challenge of generating correctness proofs for large ONNX models on machines with limited hardware capabilities. For instance, if a machine has a processing constraint of (2^{24}) but the model requires significantly more, it becomes difficult to handle.
To overcome this, the script automatically partitions a large model into multiple smaller sub-models based on a given upper threshold. It ensures that the intermediate results between sub-models are protected through hashing for privacy.
By splitting the large model, this approach enables verification of larger models on machines with average hardware. Additionally, it facilitates parallel validation of the models by allowing multiple sub-models to be validated simultaneously using multithreading or multiple machines, thus improving overall efficiency.