Skip to content

Commit

Permalink
Merge branch 'fastmachinelearning:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dgburnette committed Jun 4, 2024
2 parents f8a07f1 + b6855fe commit 7ff4cd1
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte

repos:
- repo: https://github.com/psf/black
rev: 24.4.0
rev: 24.4.2
hooks:
- id: black
language_version: python3
Expand Down
4 changes: 4 additions & 0 deletions docs/advanced/extension.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ For concreteness, let's say our custom layer ``KReverse`` is implemented in Kera
def call(self, inputs):
return tf.reverse(inputs, axis=[-1])
def get_config(self):
return super().get_config()
Make sure you define a ``get_config()`` method for your custom layer as this is needed for correct parsing.
We can define the equivalent layer in hls4ml ``HReverse``, which inherits from ``hls4ml.model.layers.Layer``.

.. code-block:: Python
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/writer/catapult_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def write_yml(self, model):
"""

def keras_model_representer(dumper, keras_model):
model_path = model.config.get_output_dir() + '/keras_model.h5'
model_path = model.config.get_output_dir() + '/keras_model.keras'
keras_model.save(model_path)
return dumper.represent_scalar('!keras_model', model_path)

Expand Down
2 changes: 1 addition & 1 deletion hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ def write_yml(self, model):
"""

def keras_model_representer(dumper, keras_model):
model_path = model.config.get_output_dir() + '/keras_model.h5'
model_path = model.config.get_output_dir() + '/keras_model.keras'
keras_model.save(model_path)
return dumper.represent_scalar('!keras_model', model_path)

Expand Down
2 changes: 1 addition & 1 deletion hls4ml/writer/vivado_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def write_yml(self, model):
"""

def keras_model_representer(dumper, keras_model):
model_path = model.config.get_output_dir() + '/keras_model.h5'
model_path = model.config.get_output_dir() + '/keras_model.keras'
keras_model.save(model_path)
return dumper.represent_scalar('!keras_model', model_path)

Expand Down
3 changes: 2 additions & 1 deletion test/pytest/ci-template.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
.pytest:
stage: test
image: gitlab-registry.cern.ch/fastmachinelearning/hls4ml-testing:0.4.base
image: gitlab-registry.cern.ch/fastmachinelearning/hls4ml-testing:0.5.6.base
tags:
- k8s-default
before_script:
- source ~/.bashrc
- git config --global --add safe.directory /builds/fastmachinelearning/hls4ml
- git submodule update --init --recursive hls4ml/templates/catapult/
- if [ $EXAMPLEMODEL == 1 ]; then git submodule update --init example-models; fi
- conda activate hls4ml-testing
Expand Down
4 changes: 4 additions & 0 deletions test/pytest/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def __init__(self):
def call(self, inputs):
return tf.reverse(inputs, axis=[-1])

def get_config(self):
# Breaks serialization and parsing in hls4ml if not defined
return super().get_config()


# hls4ml layer implementation
class HReverse(hls4ml.model.layers.Layer):
Expand Down
2 changes: 1 addition & 1 deletion test/pytest/test_weight_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def test_weight_writer(k, i, f):
print(w_paths[0])
assert len(w_paths) == 1
w_loaded = np.loadtxt(w_paths[0], delimiter=',').reshape(1, 1)
print(f'{w[0,0]:.14}', f'{w_loaded[0,0]:.14}')
print(f'{w[0, 0]:.14}', f'{w_loaded[0, 0]:.14}')
assert np.all(w == w_loaded)

0 comments on commit 7ff4cd1

Please sign in to comment.