diff --git a/example-prjs/sigmoid/README.md b/example-prjs/sigmoid/README.md
new file mode 100644
index 0000000000..ea730ebe96
--- /dev/null
+++ b/example-prjs/sigmoid/README.md
@@ -0,0 +1,9 @@
+This is a small example of a single Conv2D layer with a sigmoid activation function. To run it:
+1. Run the run_catapult.sh script.
+2. Remove "#define USE_AC_MATH" in nnet_activation.h.
+3. Replace the sigmoid_test.cpp in my-Catapult-test with the sigmoid_test.cpp one level up (if you would like the testbench to be self-checking).
+4. Move tb_input_features.dat and tb_output_predictions.dat to my-Catapult-test/tb_data (if you want the pre-computed examples).
+5. Compile:
+/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home/bin/g++ -g -std=c++11 -DSC_INCLUDE_DYNAMIC_PROCESSES -Wl,-rpath=/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home/lib,-rpath=/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home/shared/lib ./sigmoid_test.cpp ./firmware/sigmoid.cpp -I/wv/USER/venv/hls4ml/example-prjs/sigmoid/my-Catapult-test -I/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home/shared/include -L/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home/shared/lib -Wl,-Bstatic -lsystemc -Wl,-Bdynamic -lpthread -o /wv/USER/venv/hls4ml/example-prjs/sigmoid/my-Catapult-test/sigmoid
+
+Note: You can create your own input arrays and generate the corresponding predictions by editing and then running sigmoid.py.
diff --git a/example-prjs/sigmoid/catapult.py b/example-prjs/sigmoid/catapult.py
new file mode 100644
index 0000000000..79db00adbf
--- /dev/null
+++ b/example-prjs/sigmoid/catapult.py
@@ -0,0 +1,43 @@
+
+import hls4ml
+# import pprint
+import yaml
+import numpy as np
+
+print(hls4ml.__version__)
+
+with open('config.yml', 'r') as ymlfile:
+    config = yaml.safe_load(ymlfile)
+
+# try tweaking the reuse_factor on one layer to get different pipelining
+# config['HLSConfig']['LayerName']['fc1']['ReuseFactor'] = 4
+
+print('NETWORK')
+print(config)
+
+config['OutputDir'] = 'my-Catapult-test'
+config['Backend'] = 'Catapult'
+config['IOType'] = 'io_stream'
+
+config['HLSConfig']['Model']['Strategy'] = 'Latency'
+#config['HLSConfig']['Model']['Strategy'] = 'Resource'
+
+# default threshold is infinity
+config['HLSConfig']['Model']['BramFactor'] = np.inf
+# set to zero to force all weights onto the (external function) interface
+config['HLSConfig']['Model']['BramFactor'] = 0
+
+print('CURRENT CONFIGURATION')
+print('Backend='+config['Backend'])
+print('IOType='+config['IOType'])
+print('BramFactor={bf}'.format(bf=config['HLSConfig']['Model']['BramFactor']))
+
+# pprint.pprint(config)
+
+# Convert it to an hls project
+hls_model = hls4ml.converters.keras_to_hls(config)
+
+hls_model.build(vsynth=False)
+
+# URL for this info: https://fastmachinelearning.org/hls4ml/setup/QUICKSTART.html
+
diff --git a/example-prjs/sigmoid/config.yml b/example-prjs/sigmoid/config.yml
new file mode 100644
index 0000000000..42b81117f9
--- /dev/null
+++ b/example-prjs/sigmoid/config.yml
@@ -0,0 +1,15 @@
+Backend: Vivado
+KerasJson: sigmoid.json
+KerasH5: sigmoid_weights.h5
+OutputDir: my-Catapult-test
+ProjectName: sigmoid
+XilinxPart: xcku115-flvb2104-2-i
+Part: xcku115-flvb2104-2-i
+ClockPeriod: 5
+
+IOType : io_parallel
+HLSConfig:
+  Model:
+    Precision: ap_fixed<16, 6>
+    ReuseFactor: 1
+    Strategy: Latency
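
(Reader aid, not one of the committed files.) The "Precision: ap_fixed<16, 6>" setting in config.yml means 16 total bits with 6 integer bits (sign included), leaving 10 fractional bits: a step of 2^-10 and a range of roughly [-32, 32). A minimal Python sketch of that quantization, using truncation and saturation for illustration only (the actual ap_fixed default rounding and overflow modes may differ):

```python
import math

def ap_fixed_16_6(x):
    # 16 total bits, 6 integer bits -> 10 fractional bits.
    frac_bits = 16 - 6
    lsb = 2.0 ** -frac_bits            # smallest representable step (~0.000977)
    lo = -(2.0 ** (6 - 1))             # -32.0
    hi = (2.0 ** (6 - 1)) - lsb        # 31.9990234375
    q = math.floor(x / lsb) * lsb      # drop the bits below the LSB
    return min(max(q, lo), hi)

print(ap_fixed_16_6(0.99908894))   # 0.9990234375
print(ap_fixed_16_6(100.0))        # saturates to 31.9990234375
```
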
diff --git a/example-prjs/sigmoid/run_catapult.sh b/example-prjs/sigmoid/run_catapult.sh
new file mode 100755
index 0000000000..c7bfff4813
--- /dev/null
+++ b/example-prjs/sigmoid/run_catapult.sh
@@ -0,0 +1,44 @@
+#! /bin/bash
+
+# This script runs the Catapult flow to generate the HLS.
+
+VENV=../../../../venv
+
+MGC_HOME=/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home
+export MGC_HOME
+
+export PATH=/wv/hlstools/python/python37/bin:$PATH:$XILINX_VIVADO/bin:$MGC_HOME/bin
+export LD_LIBRARY_PATH=/wv/hlstools/python/python37/lib:$XILINX_VIVADO/lib/lnx64.o:$MGC_HOME/lib
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+# needed for pytest
+export OSTYPE=linux-gnu
+
+echo "Activating Virtual Environment..."
+# bash
+source $VENV/bin/activate
+
+rm -rf ./my-Catapult-test*
+
+# to run catapult+vivado_rtl
+sed -e 's/Vivado/Catapult/g' vivado.py >catapult.py
+# to only run catapult
+# sed -e 's/Vivado/Catapult/g' vivado.py | sed -e 's/vsynth=True/vsynth=False/g' >catapult.py
+
+# actually run HLS4ML + Catapult (+ optional vivado RTL)
+python3 catapult.py
+
+# run just the C++ execution
+echo ""
+echo "====================================================="
+echo "====================================================="
+echo "C++ EXECUTION"
+pushd my-Catapult-test; rm -f a.out; $MGC_HOME/bin/g++ -std=c++17 -I. -DWEIGHTS_DIR=\"firmware/weights\" -Ifirmware -I$MGC_HOME/shared/include firmware/sigmoid.cpp sigmoid_test.cpp; ./a.out; popd
+
+# Using VSCode setup generated by Catapult
+echo ""
+echo "====================================================="
+echo "====================================================="
+echo "To launch VSCode on the C++ generated by hls4ml:"
+echo "setenv LD_LIBRARY_PATH $MGC_HOME/lib:$MGC_HOME/shared/lib"
+echo "pushd my-Catapult-test; /wv/hlstools/vscode/LATEST/code Catapult.code-workspace"
diff --git a/example-prjs/sigmoid/run_vivado.sh b/example-prjs/sigmoid/run_vivado.sh
new file mode 100755
index 0000000000..280e5fc119
--- /dev/null
+++ b/example-prjs/sigmoid/run_vivado.sh
@@ -0,0 +1,39 @@
+#! /bin/bash
+
+# This script runs the Vivado flow to generate the HLS.
+
+VENV=/wv/scratch-baimar9c/venv
+
+MGC_HOME=/wv/hlsb/CATAPULT/TOT/CURRENT/aol/Mgc_home
+export MGC_HOME
+
+export PATH=/wv/hlstools/python/python37/bin:$PATH:$XILINX_VIVADO/bin:$MGC_HOME/bin
+export LD_LIBRARY_PATH=/wv/hlstools/python/python37/lib:$XILINX_VIVADO/lib/lnx64.o:$MGC_HOME/lib
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+# needed for pytest
+export OSTYPE=linux-gnu
+
+echo "Activating Virtual Environment..."
+# bash
+source $VENV/bin/activate
+
+rm -rf ./my-Vivado-test*
+
+mkdir -p tb_data
+
+# to run catapult+vivado_rtl
+sed -e 's/Vivado/Catapult/g' vivado.py >catapult.py
+# to only run catapult
+# sed -e 's/Vivado/Catapult/g' vivado.py | sed -e 's/vsynth=True/vsynth=False/g' >catapult.py
+
+# actually run HLS4ML + Vivado HLS
+python vivado.py
+
+# run just the C++ execution
+echo ""
+echo "====================================================="
+echo "====================================================="
+echo "C++ EXECUTION"
+pushd my-Vivado-test; rm -f a.out; $MGC_HOME/bin/g++ -g -std=c++11 -I. -DWEIGHTS_DIR=\"firmware/weights\" -Ifirmware -Ifirmware/ap_types -I$MGC_HOME/shared/include firmware/sigmoid.cpp sigmoid_test.cpp; ./a.out; popd
+
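
(Reader aid, not one of the committed files.) The "C++ EXECUTION" step above writes the testbench outputs to tb_data/csim_results.log inside the work directory. A hypothetical check that mirrors the 0.001 tolerance used in sigmoid_test.cpp, assuming it is run from the example directory and that the Catapult work directory is my-Catapult-test:

```python
def read_rows(path):
    # One whitespace-separated row of floats per line.
    with open(path) as f:
        return [[float(v) for v in line.split()] for line in f if line.strip()]

expected = read_rows('tb_output_predictions.dat')
actual = read_rows('my-Catapult-test/tb_data/csim_results.log')

for row, (exp, act) in enumerate(zip(expected, actual)):
    for col, (e, a) in enumerate(zip(exp, act)):
        if abs(e - a) > 0.001:
            print(f'MISMATCH row {row} col {col}: expected {e} got {a}')
print('comparison done')
```
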
diff --git a/example-prjs/sigmoid/sigmoid.h5 b/example-prjs/sigmoid/sigmoid.h5
new file mode 100644
index 0000000000..acd45108d9
Binary files /dev/null and b/example-prjs/sigmoid/sigmoid.h5 differ
diff --git a/example-prjs/sigmoid/sigmoid.json b/example-prjs/sigmoid/sigmoid.json
new file mode 100644
index 0000000000..5d41172804
--- /dev/null
+++ b/example-prjs/sigmoid/sigmoid.json
@@ -0,0 +1 @@
+{"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": [null, 5, 5, 1], "dtype": "float32", "sparse": false, "ragged": false, "name": "conv2d_input"}}, {"class_name": "Conv2D", "config": {"name": "conv2d", "trainable": true, "dtype": "float32", "batch_input_shape": [null, 5, 5, 1], "filters": 1, "kernel_size": [3, 3], "strides": [1, 1], "padding": "valid", "data_format": "channels_last", "dilation_rate": [1, 1], "groups": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": "random_kernel", "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.11.0", "backend": "tensorflow"}
\ No newline at end of file
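
(Reader aid, not one of the committed files.) sigmoid.json describes a 5x5x1 input going through a single 3x3 Conv2D with stride 1 and "valid" padding, which yields a 3x3x1 output; that is why each line of tb_output_predictions.dat further below carries 9 values. The arithmetic:

```python
# Output size for "valid" padding: out = (in - kernel) // stride + 1.
in_h = in_w = 5
kernel = 3
stride = 1
out_h = (in_h - kernel) // stride + 1
out_w = (in_w - kernel) // stride + 1
print(out_h, out_w, out_h * out_w)  # 3 3 9 -> nine sigmoid outputs per test vector
```
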
diff --git a/example-prjs/sigmoid/sigmoid.py b/example-prjs/sigmoid/sigmoid.py
new file mode 100644
index 0000000000..9701489435
--- /dev/null
+++ b/example-prjs/sigmoid/sigmoid.py
@@ -0,0 +1,112 @@
+import tensorflow as tf
+import numpy as np
+
+# Sources:
+# https://www.geeksforgeeks.org/python-tensorflow-tf-keras-layers-conv2d-function/
+# https://jiafulow.github.io/blog/2021/02/17/simple-fully-connected-nn-firmware-using-hls4ml/
+# https://stackoverflow.com/questions/51930312/how-to-include-a-custom-filter-in-a-keras-based-cnn
+
+# Create a custom kernel
+# NOTE: This kernel is a fixed, arbitrary pattern purely for testing small examples
+def random_kernel(shape=(3,3,1), dtype=None):
+
+    f = np.array([
+            [[[1]], [[-1]], [[1]]],
+            [[[-1]], [[1]], [[-1]]],
+            [[[1]], [[-1]], [[1]]]
+        ])
+    assert f.shape == shape
+    return f
+
+# Create a model with one Conv2D layer for a small example
+def create_model():
+    # Create a model
+    model = tf.keras.Sequential()
+
+    # First layer arguments:
+    # filters: Number of output filters.
+    # kernel_size: Convolution window width and height.
+    # strides: Stride of the convolution.
+    # padding: "same" adds padding if needed so the output dimensions equal the input dimensions. "valid" means no padding.
+    # activation: Non-linear activation function (e.g. relu, sigmoid).
+    # use_bias: Boolean, whether the layer uses a bias vector.
+    # dilation_rate: Dilation rate for dilated convolutions.
+    # kernel_initializer: Default is glorot_uniform, which initializes weights from a uniform distribution.
+    # bias_initializer: Initializer for the bias vector.
+    # kernel_constraint: Constraint function for the kernel.
+    # bias_constraint: Constraint function for the bias vector.
+
+    # NOTE: The input shape describes a 5x5 pixel image (matrix) with one channel (i.e. just the red channel from RGB).
+    # The image (matrix) is only 5x5 with a 3x3 kernel since this is a very small example.
+    model.add(tf.keras.layers.Conv2D(1, 3, 1, padding="valid", activation="sigmoid", kernel_initializer=random_kernel, input_shape=(5, 5, 1)))
+
+    return model
+
+# Save the model in the forms needed by hls4ml
+def save_model(model, name=None):
+    # Save as model.h5, model_weights.h5, and model.json
+    if name is None:
+        name = model.name
+    model.save(name + '.h5')
+    model.save_weights(name + '_weights.h5')
+    with open(name + '.json', 'w') as outfile:
+        outfile.write(model.to_json())
+    return
+
+if __name__ == '__main__':
+    model = create_model()
+    save_model(model, name='sigmoid')
+
+    # Image Matrix
+    image_mat = np.array([
+        [ [1], [2], [3], [4], [5] ],
+        [ [6], [7], [8], [9], [10] ],
+        [ [11], [12], [13], [14], [15] ],
+        [ [16], [17], [18], [19], [20] ],
+        [ [21], [22], [23], [24], [25] ]
+    ])
+
+    image_mat = image_mat.reshape((1, 5, 5, 1))
+
+    # Get prediction
+    prediction = model.predict(image_mat)
+    print("Image Matrix\n")
+    print(image_mat)
+    print("Prediction\n")
+    print(prediction)
+
+    image_mat2 = np.array([
+        [ [1], [2], [3], [4], [5] ],
+        [ [5], [1], [2], [3], [4] ],
+        [ [4], [5], [1], [2], [3] ],
+        [ [3], [4], [5], [1], [2] ],
+        [ [2], [3], [4], [5], [1] ]
+    ])
+
+    image_mat2 = image_mat2.reshape((1, 5, 5, 1))
+
+    # Get prediction
+    prediction = model.predict(image_mat2)
+    print("Image Matrix\n")
+    print(image_mat2)
+    print("Prediction\n")
+    print(prediction)
+
+    image_mat3 = np.array([
+        [ [-1], [2], [-3], [4], [-5] ],
+        [ [5], [-1], [2], [-3], [4] ],
+        [ [-4], [5], [-1], [2], [-3] ],
+        [ [3], [-4], [5], [-1], [2] ],
+        [ [-2], [3], [-4], [5], [-1] ]
+    ])
+
+    image_mat3 = image_mat3.reshape((1, 5, 5, 1))
+
+    # Get prediction
+    prediction = model.predict(image_mat3)
+    print("Image Matrix\n")
+    print(image_mat3)
+    print("Prediction\n")
+    print(prediction)
+
+
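
(Reader aid, not one of the committed files.) The predictions printed by sigmoid.py can be checked by hand. For the first image, the top-left 3x3 window is [[1,2,3],[6,7,8],[11,12,13]]; multiplying it elementwise by the custom kernel [[1,-1,1],[-1,1,-1],[1,-1,1]] and summing gives 7, and sigmoid(7) is about 0.99908894, the first value in tb_output_predictions.dat below:

```python
import numpy as np

# First 3x3 window of the first test image and the fixed kernel from sigmoid.py.
window = np.array([[1, 2, 3], [6, 7, 8], [11, 12, 13]])
kernel = np.array([[1, -1, 1], [-1, 1, -1], [1, -1, 1]])

acc = float(np.sum(window * kernel))   # 7.0 (the bias is zero)
sig = 1.0 / (1.0 + np.exp(-acc))       # 0.9990889...
print(acc, sig)
```
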
diff --git a/example-prjs/sigmoid/sigmoid_test.cpp b/example-prjs/sigmoid/sigmoid_test.cpp
new file mode 100644
index 0000000000..832a8afb42
--- /dev/null
+++ b/example-prjs/sigmoid/sigmoid_test.cpp
@@ -0,0 +1,138 @@
+//
+// rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks
+//
+// Copyright (C) 2017 EJ Kreinar
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program.  If not, see <http://www.gnu.org/licenses/>.
+//
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <vector>
+
+#include "firmware/sigmoid.h"
+#include "firmware/nnet_utils/nnet_helpers.h"
+// #include "firmware/parameters.h"
+
+#include <mc_scverify.h>
+
+//hls-fpga-machine-learning insert bram
+#include "firmware/weights/w2.h"
+#include "firmware/weights/b2.h"
+
+//hls-fpga-machine-learning insert declare weights
+model_default_t w2[9];
+model_default_t b2[1];
+
+namespace nnet {
+    bool trace_enabled = true;
+    std::map<std::string, void *> *trace_outputs = NULL;
+    size_t trace_type_size = sizeof(double);
+}
+
+CCS_MAIN(int argc, char *argv[])
+{
+    //load input data from text file
+    std::ifstream fin("tb_data/tb_input_features.dat");
+    //load predictions from text file
+    std::ifstream fpr("tb_data/tb_output_predictions.dat");
+
+#ifdef RTL_SIM
+    std::string RESULTS_LOG = "tb_data/rtl_cosim_results.log";
+#else
+    std::string RESULTS_LOG = "tb_data/csim_results.log";
+#endif
+    std::ofstream fout(RESULTS_LOG);
+
+#ifndef __SYNTHESIS__
+    static bool loaded_weights = false;
+    if (!loaded_weights) {
+        //hls-fpga-machine-learning insert load weights
+        nnet::load_weights_from_txt<model_default_t, 9>(w2, "w2.txt");
+        nnet::load_weights_from_txt<model_default_t, 1>(b2, "b2.txt");
+        loaded_weights = true;
+    }
+#endif
+    std::string iline;
+    std::string pline;
+
+    if (fin.is_open() && fpr.is_open()) {
+        while ( std::getline(fin,iline) && std::getline (fpr,pline) ) {
+            char* cstr=const_cast<char*>(iline.c_str());
+            char* current;
+            std::vector<float> in;
+            current=strtok(cstr," ");
+            while(current!=NULL) {
+                in.push_back(atof(current));
+                current=strtok(NULL," ");
+            }
+            cstr=const_cast<char*>(pline.c_str());
+            std::vector<float> pr;
+            current=strtok(cstr," ");
+            while(current!=NULL) {
+                pr.push_back(atof(current));
+                current=strtok(NULL," ");
+            }
+//            std::cout << " Input feature map size = " << in.size() << " Output predictions size = " << pr.size() << std::endl;
+
+            //hls-fpga-machine-learning insert data
+            ac_channel<input_t> conv2d_input/*("conv2d_input")*/;
+            nnet::copy_data<float, input_t, 0, N_INPUT_1_1*N_INPUT_2_1*N_INPUT_3_1>(in, conv2d_input);
+            ac_channel<result_t> layer3_out/*("layer3_out")*/;
+
+            //hls-fpga-machine-learning insert top-level-function
+            sigmoid(conv2d_input,layer3_out,w2,b2);
+
+            for(int i = 0; i < OUT_HEIGHT_2*OUT_WIDTH_2; i++)
+            {
+                if(fabs(pr[i] - (float)layer3_out[i][0].to_double()) > 0.001)
+                {
+                    std::cout << "FAILURE" << std::endl;
+                    std::cout << "Expected: " << pr[i] << " Actual: " << layer3_out[i][0] << std::endl;
+                }
+            }
+
+            //hls-fpga-machine-learning insert tb-output
+            nnet::print_result<result_t, OUT_HEIGHT_2*OUT_WIDTH_2>(layer3_out, fout);
+        }
+        fin.close();
+        fpr.close();
+    } else {
+        std::cout << "INFO: Unable to open input/predictions file, using default input." << std::endl;
+
+        //hls-fpga-machine-learning insert zero
+        ac_channel<input_t> conv2d_input/*("conv2d_input")*/;
+        nnet::fill_zero<input_t, N_INPUT_1_1*N_INPUT_2_1*N_INPUT_3_1>(conv2d_input);
+        ac_channel<result_t> layer3_out/*("layer3_out")*/;
+
+        //hls-fpga-machine-learning insert top-level-function
+        sigmoid(conv2d_input,layer3_out,w2,b2);
+
+        //hls-fpga-machine-learning insert output
+
+        //hls-fpga-machine-learning insert tb-output
+        nnet::print_result<result_t, OUT_HEIGHT_2*OUT_WIDTH_2>(layer3_out, fout);
+
+    }
+
+    fout.close();
+    std::cout << "INFO: Saved inference results to file: " << RESULTS_LOG << std::endl;
+
+    return 0;
+}
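
(Reader aid, not one of the committed files.) The testbench above reads whitespace-separated rows: 25 values per line in tb_input_features.dat (a flattened 5x5x1 image) and 9 values per line in tb_output_predictions.dat (the flattened 3x3 prediction). A hypothetical helper along these lines could append new test vectors from the Keras model built in sigmoid.py:

```python
import numpy as np

def append_example(model, image,
                   feat_path='tb_input_features.dat',
                   pred_path='tb_output_predictions.dat'):
    # Flatten the 5x5x1 image and its 3x3x1 prediction into one line each,
    # in the whitespace-separated format read by sigmoid_test.cpp.
    image = np.asarray(image).reshape((1, 5, 5, 1))
    pred = model.predict(image)
    with open(feat_path, 'a') as f:
        f.write(' '.join(str(float(v)) for v in image.flatten()) + '\n')
    with open(pred_path, 'a') as f:
        f.write(' '.join(str(float(v)) for v in pred.flatten()) + '\n')
```
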
diff --git a/example-prjs/sigmoid/sigmoid_weights.h5 b/example-prjs/sigmoid/sigmoid_weights.h5
new file mode 100644
index 0000000000..0043e82d97
Binary files /dev/null and b/example-prjs/sigmoid/sigmoid_weights.h5 differ
diff --git a/example-prjs/sigmoid/tb_input_features.dat b/example-prjs/sigmoid/tb_input_features.dat
new file mode 100644
index 0000000000..46b0ce53bb
--- /dev/null
+++ b/example-prjs/sigmoid/tb_input_features.dat
@@ -0,0 +1,3 @@
+1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
+1 2 3 4 5 5 1 2 3 4 4 5 1 2 3 3 4 5 1 2 2 3 4 5 1
+-1 2 -3 4 -5 5 -1 2 -3 4 -4 5 -1 2 -3 3 -4 5 -1 2 -2 3 -4 5 -1
diff --git a/example-prjs/sigmoid/tb_output_predictions.dat b/example-prjs/sigmoid/tb_output_predictions.dat
new file mode 100644
index 0000000000..db5e912a10
--- /dev/null
+++ b/example-prjs/sigmoid/tb_output_predictions.dat
@@ -0,0 +1,3 @@
+0.99908894 0.99966466 0.9998766 0.99999386 0.99999774 0.99999917 1 1 1
+0.01798621 0.99908894 0.95257413 0.9999546 0.01798621 0.99908894 0.26894143 0.9999546 0.01798621
+3.7751347e-11 1.0000000e+00 1.8795287e-12 1.0000000e+00 3.7751347e-11 1.0000000e+00 3.4424771e-14 1.0000000e+00 3.7751344e-11
diff --git a/example-prjs/sigmoid/vivado.py b/example-prjs/sigmoid/vivado.py
new file mode 100644
index 0000000000..2665211b27
--- /dev/null
+++ b/example-prjs/sigmoid/vivado.py
@@ -0,0 +1,43 @@
+
+import hls4ml
+# import pprint
+import yaml
+import numpy as np
+
+print(hls4ml.__version__)
+
+with open('config.yml', 'r') as ymlfile:
+    config = yaml.safe_load(ymlfile)
+
+# try tweaking the reuse_factor on one layer to get different pipelining
+# config['HLSConfig']['LayerName']['fc1']['ReuseFactor'] = 4
+
+print('NETWORK')
+print(config)
+
+config['OutputDir'] = 'my-Vivado-test'
+config['Backend'] = 'Vivado'
+config['IOType'] = 'io_stream'
+
+config['HLSConfig']['Model']['Strategy'] = 'Latency'
+#config['HLSConfig']['Model']['Strategy'] = 'Resource'
+
+# default threshold is infinity
+config['HLSConfig']['Model']['BramFactor'] = np.inf
+# set to zero to force all weights onto the (external function) interface
+config['HLSConfig']['Model']['BramFactor'] = 0
+
+print('CURRENT CONFIGURATION')
+print('Backend='+config['Backend'])
+print('IOType='+config['IOType'])
+print('BramFactor={bf}'.format(bf=config['HLSConfig']['Model']['BramFactor']))
+
+# pprint.pprint(config)
+
+# Convert it to an hls project
+hls_model = hls4ml.converters.keras_to_hls(config)
+
+hls_model.build(vsynth=False)
+
+# URL for this info: https://fastmachinelearning.org/hls4ml/setup/QUICKSTART.html
+
diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h
index 0c07cf7032..10efea64a0 100755
--- a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h
+++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h
@@ -128,10 +128,10 @@ void sigmoid(ac_channel<data_T> &data, ac_channel<res_T> &res) {
         #pragma hls_unroll
         SigmoidPackLoop: for (int j = 0; j < res_T::size; j++)
         {
             //#pragma HLS UNROLL
-            int data_round = in_data[j]*CONFIG_T::table_size/16;
-            int index = data_round + 8*CONFIG_T::table_size/16;
+            int data_round = (int)in_data[j].to_double()*(int)CONFIG_T::table_size/16;
+            int index = data_round + 8*(int)CONFIG_T::table_size/16;
             if (index < 0) index = 0;
-            else if (index > CONFIG_T::table_size-1) index = CONFIG_T::table_size-1;
+            else if (index > CONFIG_T::table_size-1) index = (int)CONFIG_T::table_size-1;
             out_data[j] = sigmoid_table[index];
         }
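
(Reader aid, not part of the patch.) The lookup table addressed above covers sigmoid inputs in [-8, 8): an input is scaled by table_size/16, offset by 8*table_size/16, and clamped to [0, table_size-1]. A small numeric sketch of that indexing scheme, assuming CONFIG_T::table_size is 1024 (a common hls4ml default):

```python
TABLE_SIZE = 1024  # assumed value of CONFIG_T::table_size

def sigmoid_lut_index(x):
    data_round = int(x * TABLE_SIZE / 16)       # scale: the table spans 16 input units
    index = data_round + 8 * TABLE_SIZE // 16   # shift so x = -8 maps to index 0
    return min(max(index, 0), TABLE_SIZE - 1)   # clamp, as in the C++ above

print(sigmoid_lut_index(-9.0))  # 0    (saturates low)
print(sigmoid_lut_index(0.0))   # 512  (mid-table, sigmoid ~ 0.5)
print(sigmoid_lut_index(7.0))   # 960
```
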