Packaging changes
vloncar committed Aug 27, 2019
1 parent 0cf1344 commit 59f7df6
Showing 15 changed files with 116 additions and 164 deletions.
MANIFEST.in (1 addition, 1 deletion)
@@ -3,4 +3,4 @@ include README.md
graft example-prjs
graft example-models
graft test
recursive-include hls4ml/hls-templates *
recursive-include hls4ml/templates *
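
The path change above follows the move of the HLS templates from hls4ml/hls-templates to hls4ml/templates, so source distributions keep shipping the template files. The commit's setup.py is not shown in this diff; purely as an illustrative sketch (assuming setuptools and the layout named in MANIFEST.in), the matching configuration might look like:

# Illustrative sketch only -- the actual setup.py is not part of this diff.
# Assumes setuptools; MANIFEST.in (above) lists the template files to ship.
from setuptools import setup, find_packages

setup(
    name='hls4ml',
    packages=find_packages(),
    include_package_data=True,  # pick up hls4ml/templates/* declared in MANIFEST.in
)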
example-models/keras-config.yml (4 additions, 4 deletions)
@@ -1,7 +1,7 @@
KerasJson: example-keras-model-files/KERAS_3layer.json
KerasH5: example-keras-model-files/KERAS_3layer_weights.h5
#InputData: example-keras-model-files/KERAS_3layer_input_features.dat
#OutputPredictions: example-keras-model-files/KERAS_3layer_predictions.dat
KerasJson: keras/KERAS_3layer.json
KerasH5: keras/KERAS_3layer_weights.h5
#InputData: keras/KERAS_3layer_input_features.dat
#OutputPredictions: keras/KERAS_3layer_predictions.dat
OutputDir: my-hls-test
ProjectName: myproject
XilinxPart: xcku115-flvb2104-2-i
example-models/onnx-config.yml (12 additions, 3 deletions)
@@ -1,9 +1,18 @@
OnnxModel: example-onnx-model-files/three_layer_keras.onnx
OnnxModel: onnx/three_layer_keras.onnx
#InputData: keras/KERAS_3layer_input_features.dat
#OutputPredictions: keras/KERAS_3layer_predictions.dat
OutputDir: my-hls-test
ProjectName: myproject
XilinxPart: xcku115-flvb2104-2-i
ClockPeriod: 5

IOType: io_parallel # options: io_serial/io_parallel
ReuseFactor: 1
DefaultPrecision: ap_fixed<16,6>
HLSConfig:
Model:
Precision: ap_fixed<16,6>
ReuseFactor: 1
# LayerType:
# Dense:
# ReuseFactor: 2
# Strategy: Resource
# Compression: True
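
The flat ReuseFactor and DefaultPrecision keys give way to a nested HLSConfig block: model-wide defaults under Model, with optional per-layer-type overrides sketched in the commented LayerType section. As a rough illustration of how such a block can be consumed, the following sketch loads the YAML and resolves a setting for a layer type, falling back to the Model default; the helper name and precedence are assumptions for illustration, not hls4ml's actual lookup code.

# Sketch only: resolves values from the nested HLSConfig layout shown above.
# get_layer_setting and its precedence are assumptions, not the hls4ml API.
import yaml

def get_layer_setting(hls_config, layer_type, key):
    layer_cfg = hls_config.get('LayerType', {}).get(layer_type, {})
    model_cfg = hls_config.get('Model', {})
    return layer_cfg.get(key, model_cfg.get(key))

with open('onnx-config.yml') as f:
    cfg = yaml.safe_load(f)

hls_config = cfg['HLSConfig']
print(get_layer_setting(hls_config, 'Dense', 'ReuseFactor'))  # 1 (Model default)
print(get_layer_setting(hls_config, 'Dense', 'Precision'))    # ap_fixed<16,6>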
example-models/pytorch-config.yml (14 additions, 6 deletions)
@@ -1,10 +1,18 @@
#PytorchModel: example-models/two_layer_model.pt
PytorchModel: example-models/three_layer_model.pt
OutputDir: my-hls-dir-3L
PytorchModel: pytorch/three_layer_model.pt
#InputData: keras/KERAS_3layer_input_features.dat
#OutputPredictions: keras/KERAS_3layer_predictions.dat
OutputDir: my-hls-test
ProjectName: myproject
XilinxPart: xc7vx690tffg1927-2
XilinxPart: xcku115-flvb2104-2-i
ClockPeriod: 5

IOType: io_parallel # options: io_serial/io_parallel
ReuseFactor: 1
DefaultPrecision: ap_fixed<18,8>
HLSConfig:
Model:
Precision: ap_fixed<16,6>
ReuseFactor: 1
# LayerType:
# Dense:
# ReuseFactor: 2
# Strategy: Resource
# Compression: True
hls4ml/converters/keras_to_hls.py (5 additions, 52 deletions)
@@ -1,23 +1,13 @@
from __future__ import print_function
import numpy as np
import h5py
import os
import tarfile
import json
import argparse
import yaml
import sys
from shutil import copyfile
import math

MAXMULT = 4096
from hls4ml.model import HLSModel
from hls4ml.model.optimizer import optimize_model

filedir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0,os.path.join(filedir, "..", "hls-writer"))
from hls_writer import parse_config, write_hls
sys.path.insert(0,os.path.join(filedir, "..", "hls-writer/optimizer"))
from optimizer import optimize_model
from hls_model import HLSModel
MAXMULT = 4096

class KerasDataReader:
    def __init__(self, config):
@@ -49,42 +39,11 @@ def h5_visitor_func(name):

    return shape

############################################################################################
## M A I N
############################################################################################
def main():

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-c", action='store', dest='config',
        help="Configuration file.")
    args = parser.parse_args()
    if not args.config: parser.error('A configuration file needs to be specified.')

    configDir = os.path.abspath(os.path.dirname(args.config))
    yamlConfig = parse_config(args.config)
    if not os.path.isabs(yamlConfig['OutputDir']):
        yamlConfig['OutputDir'] = os.path.join(configDir, yamlConfig['OutputDir'])
    if not os.path.isabs(yamlConfig['KerasH5']):
        yamlConfig['KerasH5'] = os.path.join(configDir, yamlConfig['KerasH5'])
    if not os.path.isabs(yamlConfig['KerasJson']):
        yamlConfig['KerasJson'] = os.path.join(configDir, yamlConfig['KerasJson'])
    if 'InputData' in yamlConfig and not os.path.isabs(yamlConfig['InputData']):
        yamlConfig['InputData'] = os.path.join(configDir, yamlConfig['InputData'])
    if 'OutputPredictions' in yamlConfig and not os.path.isabs(yamlConfig['OutputPredictions']):
        yamlConfig['OutputPredictions'] = os.path.join(configDir, yamlConfig['OutputPredictions'])

    if not (yamlConfig["IOType"] == "io_parallel" or yamlConfig["IOType"] == "io_serial"):
        raise Exception('ERROR: Invalid IO type')
    return yamlConfig

def keras_to_hls_model(yamlConfig):
def keras_to_hls(yamlConfig):

    ######################
    ## Do translation
    ######################
    if not os.path.isdir("{}/firmware/weights".format(yamlConfig['OutputDir'])):
        os.makedirs("{}/firmware/weights".format(yamlConfig['OutputDir']))

    #This is a list of dictionaries to hold all the layer info we need to generate HLS
    layer_list = []
@@ -376,14 +335,8 @@ def keras_to_hls_model(yamlConfig):
    #################

    reader = KerasDataReader(yamlConfig)
    print('Creating HLS model')
    hls_model = HLSModel(yamlConfig, reader, layer_list, input_layers, output_layers)
    optimizers = ['eliminate_linear_activation', 'merge_batch_norm_quantized_tanh', 'quantize_dense_output']
    optimize_model(hls_model, optimizers)
    #write_hls(hls_model)
    return hls_model


if __name__ == "__main__":
    yamlConfig = main()
    hls_model = keras_to_hls_model(yamlConfig)
    write_hls(hls_model)
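
With the standalone main() entry point and the sys.path manipulation gone, the Keras converter is now a plain package function: keras_to_hls(yamlConfig) builds the HLSModel, runs the optimizer passes and returns the model, leaving project writing to the caller. A minimal driver sketch under the new layout (the config is loaded with plain PyYAML here, so the path resolution and IOType validation that the removed main() performed are skipped; the write_hls import path is the one used by onnx_to_hls.py below):

# Minimal driver sketch for the packaged layout; file names are examples only.
import yaml

from hls4ml.converters.keras_to_hls import keras_to_hls
from hls4ml.writer.vivado_writer import write_hls  # import path as in onnx_to_hls.py

with open('keras-config.yml') as f:   # e.g. the config from example-models/
    yamlConfig = yaml.safe_load(f)

hls_model = keras_to_hls(yamlConfig)  # builds the HLSModel and runs optimizer passes
write_hls(hls_model)                  # writes the HLS project into OutputDir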
hls4ml/converters/onnx_to_hls.py (9 additions, 40 deletions)
@@ -1,23 +1,14 @@
from __future__ import print_function
import numpy as np
import h5py
import os
import tarfile
import json
import argparse
import yaml
import sys
from shutil import copyfile
import math
from onnx import ModelProto, GraphProto, NodeProto, TensorProto
from onnx import optimizer, helper, numpy_helper, shape_inference

MAXMULT = 4096
from hls4ml.writer.vivado_writer import write_hls
from hls4ml.model import HLSModel
from hls4ml.model.optimizer import optimize_model

filedir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0,os.path.join(filedir, "..", "hls-writer"))
from hls_writer import parse_config, write_hls
from hls_model import HLSModel
MAXMULT = 4096

class ONNXDataReader:
    def __init__(self, model):
@@ -126,33 +117,11 @@ def compute_pads_2d(operation, layer):

    return pads

############################################################################################
## M A I N
############################################################################################
def main():

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-c", action='store', dest='config',
        help="Configuration file.")
    args = parser.parse_args()
    if not args.config: parser.error('A configuration file needs to be specified.')

    configDir = os.path.abspath(os.path.dirname(args.config))
    yamlConfig = parse_config(args.config)
    if not os.path.isabs(yamlConfig['OutputDir']):
        yamlConfig['OutputDir'] = os.path.join(configDir, yamlConfig['OutputDir'])
    if not os.path.isabs(yamlConfig['OnnxModel']):
        yamlConfig['OnnxModel'] = os.path.join(configDir, yamlConfig['OnnxModel'])

    if not (yamlConfig["IOType"] == "io_parallel" or yamlConfig["IOType"] == "io_serial"):
        raise Exception('ERROR: Invalid IO type')
def onnx_to_hls(yamlConfig):

    ######################
    ## Do translation
    ######################
    if not os.path.isdir("{}/firmware/weights".format(yamlConfig['OutputDir'])):
        os.makedirs("{}/firmware/weights".format(yamlConfig['OutputDir']))

    #This is a list of dictionaries to hold all the layer info we need to generate HLS
    layer_list = []
@@ -402,8 +371,8 @@ def main():
    ## Generate HLS
    #################

    print('Creating HLS model')
    hls_model = HLSModel(yamlConfig, reader, layer_list, input_layers, output_layers)
    write_hls(hls_model)

if __name__ == "__main__":
    main()
    optimizers = ['eliminate_linear_activation', 'merge_batch_norm_quantized_tanh', 'quantize_dense_output']
    optimize_model(hls_model, optimizers)
    return hls_model
hls4ml/model/hls_model.py (1 addition, 1 deletion)
@@ -5,7 +5,7 @@
from enum import Enum
from collections import OrderedDict

from templates import get_config_template, get_function_template
from .templates import get_config_template, get_function_template

class HLSConfig(object):
    def __init__(self, config):
hls4ml/model/optimizer/__init__.py (11 additions, 0 deletions)
@@ -0,0 +1,11 @@
from __future__ import absolute_import

from .optimizer import OptimizerPass, register_pass, get_optimizer, optimize_model


from .passes.nop import EliminateLinearActivation
from .passes.bn_quant import MergeBatchNormAndQuantizedTanh, QuantizeDenseOutput

register_pass('eliminate_linear_activation', EliminateLinearActivation)
register_pass('merge_batch_norm_quantized_tanh', MergeBatchNormAndQuantizedTanh)
register_pass('quantize_dense_output', QuantizeDenseOutput)
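
Registration of the built-in passes now happens in the optimizer subpackage's __init__.py rather than at the bottom of optimizer.py, so importing hls4ml.model.optimizer is enough to make them available by name to optimize_model. For illustration, a hedged sketch of a user-defined pass plugging into the same mechanism; the transform(self, model, node) signature and the model.remove_node helper are assumptions based on how the passes are used here, not interfaces shown in this diff:

# Hypothetical custom pass; transform() and model.remove_node are assumed hooks.
from hls4ml.model.optimizer import OptimizerPass, register_pass, optimize_model

class EliminateDropout(OptimizerPass):
    def match(self, node):
        # fire on nodes whose layer class is named 'Dropout' (illustrative check)
        return node.__class__.__name__ == 'Dropout'

    def transform(self, model, node):
        model.remove_node(node)  # assumed graph-editing helper on HLSModel
        return True              # report that the graph changed

register_pass('eliminate_dropout', EliminateDropout)
# optimize_model(hls_model, ['eliminate_linear_activation', 'eliminate_dropout'])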
hls4ml/model/optimizer/optimizer.py (0 additions, 7 deletions)
@@ -41,10 +41,3 @@ def optimize_model(model, passes=None):
            break
        else:
            optimization_done = True

from passes.nop import EliminateLinearActivation
from passes.bn_quant import MergeBatchNormAndQuantizedTanh, QuantizeDenseOutput

register_pass('eliminate_linear_activation', EliminateLinearActivation)
register_pass('merge_batch_norm_quantized_tanh', MergeBatchNormAndQuantizedTanh)
register_pass('quantize_dense_output', QuantizeDenseOutput)
hls4ml/model/optimizer/passes/bn_quant.py (3 additions, 6 deletions)
@@ -1,12 +1,9 @@
import numpy as np
import sys
import re

sys.path.insert(0, '../')
from optimizer import OptimizerPass
sys.path.insert(0, '../..')
import hls_model
import templates
from ..optimizer import OptimizerPass
import hls4ml.model.hls_model as hls_model
import hls4ml.model.templates as templates

class BatchNormalizationQuantizedTanh(hls_model.Layer):
    ''' Merged Batch Normalization and quantized (binary or ternary) Tanh layer.
hls4ml/model/optimizer/passes/nop.py (1 addition, 3 deletions)
@@ -1,6 +1,4 @@
import sys
sys.path.insert(0, '../')
from optimizer import OptimizerPass
from ..optimizer import OptimizerPass

class EliminateLinearActivation(OptimizerPass):
    def match(self, node):
hls4ml/templates/vivado/firmware/parameters.h (10 additions, 10 deletions)
@@ -4,16 +4,16 @@
#include <complex>
#include "ap_int.h"
#include "ap_fixed.h"
#include "nnet_dense.h"
#include "nnet_dense_large.h"
#include "nnet_dense_compressed.h"
#include "nnet_conv.h"
#include "nnet_conv2d.h"
#include "nnet_activation.h"
#include "nnet_common.h"
#include "nnet_batchnorm.h"
#include "nnet_pooling.h"
#include "nnet_merge.h"
#include "nnet_utils/nnet_dense.h"
#include "nnet_utils/nnet_dense_large.h"
#include "nnet_utils/nnet_dense_compressed.h"
#include "nnet_utils/nnet_conv.h"
#include "nnet_utils/nnet_conv2d.h"
#include "nnet_utils/nnet_activation.h"
#include "nnet_utils/nnet_common.h"
#include "nnet_utils/nnet_batchnorm.h"
#include "nnet_utils/nnet_pooling.h"
#include "nnet_utils/nnet_merge.h"

//hls-fpga-machine-learning insert numbers

hls4ml/templates/vivado/myproject_test.cpp (0 additions, 1 deletion)
@@ -24,7 +24,6 @@

#include "firmware/parameters.h"
#include "firmware/myproject.h"
#include "nnet_helpers.h"

#define CHECKPOINT 5000

(remaining 2 changed files not shown)
