Compiling PyTorch Models to Calyx #2056
Overview
Hi folks! This post is about my M.Eng project from this past semester. I've been working with @jiahanxie353 to compile PyTorch programs into Calyx (and then Verilog, so we can run them on real hardware). We've successfully lowered a basic feed-forward neural network with one hidden layer and a ReLU activation, written in PyTorch, into Calyx and run it on an FPGA. In the rest of this post, I'll describe what we've done so far to make this pipeline work and what remains for future work. This post extends the discussion here.
Architecture and Implementation
We implemented a pipeline that compiles a simple feed-forward neural network from PyTorch down into Calyx. Doing this required quite a few tools. First, we used Allo to lower from PyTorch into MLIR. Allo is an accelerator design language (ADL) developed by Professor Zhang's research group; it takes machine learning models written in Python and lowers them directly into MLIR (and various MLIR dialects, such as Tensor, Linalg, etc.). From there, we run several MLIR passes and use CIRCT to lower into Calyx, from which we can emit SystemVerilog directly and run the model on an FPGA. Here is a diagram describing the architecture:
[Architecture diagram: PyTorch → Allo → MLIR → CIRCT → Calyx → SystemVerilog → FPGA]
To run the generated Verilog designs on an FPGA, we used Vivado, AMD's toolchain for simulating and programming FPGAs. The PyTorch model we compiled was quite simple. Here is the code:
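For concreteness, a model of the shape described above looks something like the following minimal sketch (the layer sizes here are illustrative, not necessarily the exact ones we used):

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """One hidden layer with a ReLU nonlinearity; sizes are illustrative."""
    def __init__(self, in_features=4, hidden=8, out_features=2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = FeedForward()
out = model(torch.randn(1, 4))  # one forward pass on a random input
```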
Running this through the pipeline took a fair bit of modification to the tools we were using. First, neither CIRCT nor Calyx supported floating-point operations. We implemented them using Berkeley's HardFloat library, which provides floating-point arithmetic modules written in Chisel. Chisel can emit Verilog directly, so we simply wrapped the emitted Verilog modules in Calyx constructs and imported them into our Calyx designs.
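Calyx lets you declare an external Verilog module as a primitive inside an `extern` block. A sketch of such a wrapper, with hypothetical names and simplified ports (the real HardFloat units also expose signals like rounding mode and exception flags):

```
extern "HardFloat/fp_add.sv" {
  // Hypothetical floating-point adder primitive; WIDTH would be 32
  // for single precision. Names and ports are illustrative.
  primitive fp_add[WIDTH](
    @clk clk: 1, @reset reset: 1, @go go: 1,
    left: WIDTH, right: WIDTH
  ) -> (out: WIDTH, @done done: 1);
}
```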
Importing external Verilog this way caused quite a few issues with how Calyx manages dependencies and external modules. We resolved this using Morty, a Rust tool that stitches Verilog files together into one large file containing all of the dependencies. Invoking Morty required some changes to the Calyx backend. Morty takes as input a JSON file that describes all of the dependencies, e.g. something like this (a sketch; the file names are illustrative):
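```json
[
  {
    "include_dirs": ["hardfloat/include"],
    "defines": {},
    "files": [
      "hardfloat/addRecFN.v",
      "hardfloat/mulRecFN.v",
      "main.sv"
    ]
  }
]
```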
Now, instead of searching through the dependencies and trying to stitch them together itself, the Calyx backend iterates through the external dependencies and creates such a JSON file. Then it invokes Morty, which produces a single Verilog file with all of the necessary hardware constructs. The code for this can be found here.
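As a sketch of that flow (written in Python for brevity; the real implementation lives in the Rust backend linked above, and the exact manifest keys and Morty flags here are assumptions):

```python
import json
import subprocess

def pickle_externs(extern_files, output="out.sv"):
    """Write a Morty-style manifest listing the extern Verilog files,
    then invoke Morty to produce one self-contained Verilog file.
    Manifest keys and flag names are assumptions, not the exact backend code."""
    manifest = [{"include_dirs": [], "defines": {}, "files": list(extern_files)}]
    with open("sources.json", "w") as f:
        json.dump(manifest, f)
    # -f: read the manifest; -o: write the stitched-together output.
    subprocess.run(["morty", "-f", "sources.json", "-o", output], check=True)

# Example invocation with illustrative file names:
pickle_externs(["hardfloat/addRecFN.v", "hardfloat/mulRecFN.v", "main.sv"])
```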
After adding support for Morty and floating-point modules to Calyx, we turned to the modifications needed in CIRCT. There are a few fundamental ways in which MLIR and Calyx differ, and they affect how the emitted Calyx code needs to be generated. First, MLIR by default supports "global memories" (things like the weights of the neural network). Calyx has no such thing; it uses external memories instead, and the data for these external memories needs to be supplied in JSON files. We discovered that when the CIRCT compiler saw MLIR code referencing these global memories, it simply chose not to emit anything. We fixed this so that CIRCT now reads the MLIR global data, generates a JSON file containing it, and emits the corresponding lines of Calyx that use the `@external` tag. Doing this required adding JSON support to CIRCT, which may or may not be something the maintainers are okay with us pushing upstream. We made a PR that is still being reviewed and refined.
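For a sense of what this pairing looks like, here is a sketch: an `@external` memory declared in the Calyx program and its matching entry in the JSON data file (the cell name, sizes, and format keys are all illustrative):

```
cells {
  // Hypothetical cell: a 32-bit memory with 4 elements (2-bit address)
  // holding one layer's weights.
  @external weights = comb_mem_d1(32, 4, 2);
}
```

```json
{
  "weights": {
    "data": [0.125, -0.5, 0.25, 1.0],
    "format": { "numeric_type": "ieee754_float", "is_signed": true, "width": 32 }
  }
}
```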
Another issue we ran into is that the higher-level dialects of MLIR support multi-dimensional memory accesses, but the lower-level dialects (e.g., SCF) do not. In other words, code that looks something like `mem[x][y]` will not pass through the pipeline. This required us to manually iterate over the MLIR AST nodes and flatten the data. Once the data was flattened, it was actually easier to write out the JSON file needed by Calyx. In addition to flattening the data itself, we also had to flatten the loops that iterate over it. MLIR supports constructs such as nested loops for traversing multi-dimensional data, but we needed to make every loop nest only one loop deep to support the later translations. An interesting point is that Calyx does support multi-dimensional memories (you can just nest arrays in the `data` attribute of a memory), so it isn't immediately clear why we have to flatten in the intermediate representation. Perhaps this is something somebody could look into in future phases of this project.
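A small sketch of the flattening, with hypothetical shapes (row-major, which is also the layout written into the JSON data file):

```python
# A 2-D access mem[x][y] over an NX-by-NY memory becomes a 1-D access
# flat[x * NY + y], and the doubly nested loop becomes a single loop.
NX, NY = 4, 8
mem = [[float(x * NY + y) for y in range(NY)] for x in range(NX)]

# Row-major flattening of the data.
flat = [mem[x][y] for x in range(NX) for y in range(NY)]

total = 0.0
for i in range(NX * NY):        # one loop, one dimension deep
    x, y = divmod(i, NY)        # recover the 2-D indices if needed
    assert flat[i] == mem[x][y]
    total += flat[i]
```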
The other major task in our CIRCT modifications was adding support for floating-point add, floating-point multiply, and ReLU (the nonlinearity). Adding support for floating-point operations required creating a new operation in CIRCT called `ConstantOp`. The details can be found more concretely in this PR.
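ReLU itself is cheap in hardware: since the values are IEEE-754 floats, it can be realized as a check on the sign bit rather than a full floating-point comparison. A minimal software sketch of that idea (not necessarily the exact lowering in the PR):

```python
import struct

def relu_bits(x_bits: int) -> int:
    """ReLU on the raw 32-bit pattern of an IEEE-754 float: if the sign
    bit is set, output 0.0; otherwise pass the bits through unchanged.
    (This maps -0.0 to +0.0, which is fine for ReLU.)"""
    return 0 if (x_bits >> 31) & 1 else x_bits

def float_to_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_float(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]

assert bits_to_float(relu_bits(float_to_bits(-2.5))) == 0.0
assert bits_to_float(relu_bits(float_to_bits(3.25))) == 3.25
```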
After all of these changes were implemented, we basically had our pipeline! Running this behemoth of a command on the MLIR file produced by Allo will generate Calyx:
Here is some of the generated Calyx:
I've omitted some of the generated code because the file is quite long. Running this code through fud and targeting the FPGA using the Xilinx tools produces our result!
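For reference, the invocation looks roughly like this, assuming fud's standard Xilinx flow (the `fpga` target and `fpga.data` parameter; the file names are placeholders):

```
fud e model.futil --to fpga -s fpga.data data.json
```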
Discussion and Next Steps
This project was successful in proving the concept that we can indeed translate PyTorch models into Calyx and then Verilog for execution on real hardware. The next steps are to refine the stages of the pipeline. Namely:
The above three items are more for the sake of completeness and cleanliness than new research areas or discussions. However, there are many more avenues this project could explore if there is interest within Capra:
Overall, I really enjoyed working on this project this semester! Huge shoutout to @jiahanxie353, who served as my mentor and partner for the semester. Thanks to @rachitnigam and @sampsyo for the advice early in the semester about the design and implementation, and for helping with the debugging. I believe Jiahan will be giving a demo of this at the Calyx meeting on Monday, May 27, so if you want to learn more I strongly encourage you to attend! Feel free to leave your thoughts about the current work and future directions in this discussion thread.