I am trying to use hls4ml.utils.config_from_keras_model() to get the config from my Keras model, but I get the error IndexError: list index out of range.
Details
I created a simple perceptron in Keras with pretrained weights and biases. When I try to generate the config for hls4ml, an error is raised.
Steps to Reproduce
This is the python script to generate the keras model:
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense

# Neural network constants
gain1 = 0.000455892409391384
ymin1 = -1
b1 = -0.15387008833216031523
IW1 = np.array([0.1125322203431925322, -0.34492392384017417362, 0.7052591999722942484,
                -1.2719945400779701927, 2.089129767777872182, -1.2416911002851198642,
                0.53266405909136382846, -0.2257313111056203081, 0.076208495617127072763])
b2 = -0.34191592641085533089
LW2 = 1.2490161057030868541
ymin2 = -1
gain2 = 2165.98876

# Build the model
def build_model():
    # Input layer
    input_layer = Input(shape=(len(IW1),), name="Input")
    # Layer 1 (Dense with custom weights and bias, using tanh activation)
    dense1 = Dense(
        units=1,
        activation="tanh",
        use_bias=True,
        kernel_initializer=tf.constant_initializer(IW1.reshape(-1, 1)),
        bias_initializer=tf.constant_initializer(b1),
        name="Hidden",
    )(input_layer)
    # Layer 2 (Dense with single weight and bias)
    output_layer = Dense(
        units=1,
        activation="linear",
        use_bias=True,
        kernel_initializer=tf.constant_initializer(LW2),
        bias_initializer=tf.constant_initializer(b2),
        name="Output",
    )(dense1)
    # Create the model
    model = Model(inputs=input_layer, outputs=output_layer, name="Perceptron")
    return model

# Instantiate and compile the model
model = build_model()
model.compile(optimizer="adam", loss="mse")
# Summary of the model
model.summary()
# Save the model to a .keras file
model.save("./perceptron.keras")
```
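For checking the converted design later, it may help to spell out what the two Dense layers above compute: y = LW2 · tanh(IW1 · x + b1) + b2 for a 9-element input x. The sketch below reproduces that forward pass in plain Python using the constants from the script. gain1, ymin1, gain2 and ymin2 are defined in the script but never used by the Keras model, so they are left out here as well. The function name perceptron_forward is hypothetical and not part of Keras or hls4ml.

```python
import math

# Constants copied from the model script above
b1 = -0.15387008833216031523
IW1 = [0.1125322203431925322, -0.34492392384017417362, 0.7052591999722942484,
       -1.2719945400779701927, 2.089129767777872182, -1.2416911002851198642,
       0.53266405909136382846, -0.2257313111056203081, 0.076208495617127072763]
b2 = -0.34191592641085533089
LW2 = 1.2490161057030868541


def perceptron_forward(x):
    """Hypothetical reference forward pass: Hidden (tanh) followed by Output (linear)."""
    if len(x) != len(IW1):
        raise ValueError(f"expected {len(IW1)} inputs, got {len(x)}")
    # Hidden layer: one unit, weights IW1, bias b1, tanh activation
    hidden = math.tanh(sum(w * xi for w, xi in zip(IW1, x)) + b1)
    # Output layer: one unit, weight LW2, bias b2, linear activation
    return LW2 * hidden + b2


# With an all-zero input only the biases contribute: LW2 * tanh(b1) + b2
print(perceptron_forward([0.0] * 9))
```

The same inputs can be fed to model.predict(...) and, once conversion succeeds and the HLS model has been compiled, to hls_model.predict(...) as a quick sanity check.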
This is the python script for using HLS4ML:
```python
import keras
import hls4ml
import json

model_path = 'perceptron.keras'

# Load the model from the .keras file
keras_model = keras.models.load_model(model_path)

# Generate a simple configuration from keras model
config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')

# Convert to an hls model
hls_model = hls4ml.converters.convert_from_keras_model(keras_model, hls_config=config, output_dir='.')
```
Expected behavior
The HLS4ML script should read the config and convert the keras model to an HLS model.
Actual behavior
The HLS4ML script fails at the config stage (config_from_keras_model) and never reaches the conversion.
Error
```
2024-11-20 09:20:22.246408: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-20 09:20:22.246749: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-11-20 09:20:22.248817: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-11-20 09:20:22.254286: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732090822.263536 88589 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732090822.266125 88589 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-20 09:20:22.275680: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARN: Unable to import optimizer(s) from expr_templates.py: No module named 'sympy'
/home/fran/.local/lib/python3.10/site-packages/hls4ml/converters/__init__.py:27: UserWarning: WARNING: Pytorch converter is not enabled!
  warnings.warn("WARNING: Pytorch converter is not enabled!", stacklevel=1)
WARNING: Failed to import handlers from core.py: No module named 'torch'.
WARNING: Failed to import handlers from convolution.py: No module named 'torch'.
WARNING: Failed to import handlers from merge.py: No module named 'torch'.
WARNING: Failed to import handlers from pooling.py: No module named 'torch'.
WARNING: Failed to import handlers from reshape.py: No module named 'torch'.
W0000 00:00:1732090823.476630 88589 gpu_device.cc:2344] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
/home/fran/.local/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 10 variables whereas the saved optimizer has 2 variables.
  saveable.load_own_variables(weights_store.get(inner_path))
Interpreting Model
Topology:
Traceback (most recent call last):
  File "/home/fran/work_repository/machine_learning_fpga/python/issue_github/hls4ml_dev.py", line 11, in <module>
    config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')
  File "/home/fran/.local/lib/python3.10/site-packages/hls4ml/utils/config.py", line 159, in config_from_keras_model
    layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader)
  File "/home/fran/.local/lib/python3.10/site-packages/hls4ml/converters/keras_to_hls.py", line 260, in parse_keras_model
    input_shapes = [output_shapes[inbound_node[0]] for inbound_node in keras_layer['inbound_nodes'][0]]
IndexError: list index out of range
```
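Because this failure depends on which TensorFlow, Keras and hls4ml versions are installed, it helps to state them in the report. The sketch below reads them with the standard-library importlib.metadata, so it works even when a package fails to import. The helpers report_versions and keras_major are hypothetical names, and the "Keras 3" message rests on an assumption: that the serialization change discussed in this thread is the one introduced with Keras 3.

```python
import importlib.metadata


def report_versions(packages=("tensorflow", "keras", "hls4ml")):
    """Return {package: installed version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def keras_major(version):
    """Major version number from a string such as '3.6.0'; None if unknown."""
    if not version:
        return None
    head = version.split(".")[0]
    return int(head) if head.isdigit() else None


if __name__ == "__main__":
    versions = report_versions()
    for name, version in versions.items():
        print(f"{name}: {version or 'not installed'}")
    major = keras_major(versions["keras"])
    # Assumption: Keras 3 is the release that changed the serialized model format
    if major is not None and major >= 3:
        print("Keras 3 detected: the model is serialized in the newer format.")
```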
The serialized format of Keras models changed in recent versions, which leads to this error. You have two options: use older versions of TensorFlow and Keras, or use the main branch of hls4ml (soon to be released as 1.0.0).
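To see why the parser trips over newer models, it helps to compare how a layer's inbound connections are serialized. The pure-Python sketch below uses hand-written dictionaries that approximate what model.to_json() produces for the Hidden layer in the older nested-list layout and in the newer dict layout. These shapes are an assumption; printing json.loads(keras_model.to_json())['config']['layers'] shows the real output. legacy_inbound_names mirrors the indexing on the failing line of the traceback, and inbound_names is a hypothetical version-agnostic helper, not hls4ml's actual fix. One plausible path to this IndexError is the [0] index applied to an empty inbound_nodes list, which is how input layers are serialized; with the dict layout the old indexing does not crash but returns key fragments instead of layer names.

```python
# Assumed shapes of the "inbound_nodes" entry for the Dense layer "Hidden"
KERAS2_DENSE = {"class_name": "Dense", "name": "Hidden",
                "inbound_nodes": [[["Input", 0, 0, {}]]]}
KERAS3_DENSE = {"class_name": "Dense", "name": "Hidden",
                "inbound_nodes": [{"args": [{"class_name": "__keras_tensor__",
                                             "config": {"shape": [None, 9], "dtype": "float32",
                                                        "keras_history": ["Input", 0, 0]}}],
                                   "kwargs": {}}]}
INPUT_LAYER = {"class_name": "InputLayer", "name": "Input", "inbound_nodes": []}


def legacy_inbound_names(layer):
    # Same indexing pattern as the failing line in keras_to_hls.py
    return [node[0] for node in layer["inbound_nodes"][0]]


def inbound_names(layer):
    """Hypothetical helper: names of the layers feeding `layer`, for either layout."""
    nodes = layer.get("inbound_nodes", [])
    if not nodes:
        return []  # input layers have no inbound nodes
    first = nodes[0]
    if isinstance(first, list):  # older layout: [[name, node_index, tensor_index, kwargs], ...]
        return [entry[0] for entry in first]
    names = []
    for arg in first.get("args", []):  # newer layout: tensors carry "keras_history"
        tensors = arg if isinstance(arg, list) else [arg]
        for tensor in tensors:
            if isinstance(tensor, dict) and tensor.get("class_name") == "__keras_tensor__":
                names.append(tensor["config"]["keras_history"][0])
    return names


for label, layer in (("old layout", KERAS2_DENSE), ("new layout", KERAS3_DENSE)):
    print(label, "legacy:", legacy_inbound_names(layer), "helper:", inbound_names(layer))
# old layout legacy: ['Input'] helper: ['Input']
# new layout legacy: ['a', 'k'] helper: ['Input']
```

Iterating over a dict yields its keys ("args", "kwargs"), which is why the legacy expression returns their first characters for the newer layout.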
Scripts and keras model files
issue_github.zip