Skip to content

Commit

Permalink
Adding ac_math implementation for softplus, softsign, elu and selu ac…
Browse files Browse the repository at this point in the history
…tivation functions
  • Loading branch information
Adithya Beemanapalli committed Jul 4, 2023
1 parent 8933490 commit b82e83d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 36 deletions.
49 changes: 31 additions & 18 deletions hls4ml/templates/catapult/nnet_utils/nnet_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@
#include <cmath>
#include "ac_fixed.h"
#include "ac_std_float.h"
#include "ac_math/ac_softmax_pwl.h"
#include "ac_math/ac_tanh_pwl.h"
#include "ac_math/ac_sigmoid_pwl.h"
#include "ac_math/ac_pow_pwl.h"
#include <ac_math/ac_softmax_pwl.h>
#include <ac_math/ac_tanh_pwl.h>
#include <ac_math/ac_sigmoid_pwl.h>
#include <ac_math/ac_pow_pwl.h>
#include <ac_math/ac_elu_pwl.h>
#include <ac_math/ac_selu_pwl.h>
#include <ac_math/ac_softplus_pwl.h>
#include <ac_math/ac_softsign_pwl.h>
#include "nnet_common.h"

namespace nnet {
Expand Down Expand Up @@ -776,7 +780,9 @@ void softplus(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
template<class data_T, class res_T, typename CONFIG_T>
void softplus(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
assert("softplus not implemented in AC Math");
for (int ii=0; ii<CONFIG_T::n_in; ii++) {
res[ii] = ac_math::ac_softplus_pwl(data[ii]);
}
}

#endif
Expand Down Expand Up @@ -856,7 +862,9 @@ void softsign(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
template<class data_T, class res_T, typename CONFIG_T>
void softsign(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
assert("softsign not implemented in AC Math");
for (int ii=0; ii<CONFIG_T::n_in; ii++) {
res[ii] = ac_math::ac_softsign_pwl(data[ii]);
}
}

#endif
Expand Down Expand Up @@ -898,12 +906,19 @@ void init_elu_table(typename CONFIG_T::table_t table_out[N_TABLE])
#endif
}

#ifndef USE_AC_MATH

template<class data_T, class res_T, typename CONFIG_T>
void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in])
{
// Initialize the lookup table
#ifdef __HLS_SYN__
bool initialized = false;
typename CONFIG_T::table_t elu_table[CONFIG_T::table_size];
#else
static bool initialized = false;
static typename CONFIG_T::table_t elu_table[CONFIG_T::table_size];
#endif

if (!initialized) {
init_elu_table<CONFIG_T, CONFIG_T::table_size>(elu_table);
Expand All @@ -928,22 +943,18 @@ void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_
}
}

#else

template<class data_T, class res_T, typename CONFIG_T>
void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in])
{
typedef ac_fixed<res_T::width, res_T::i_width, false, res_T::q_mode, res_T::o_mode> res_unsigned_T;
res_unsigned_T x;
for (int i=0; i<CONFIG_T::n_in; i++) {
if (data[i] > 0) {
res[i] = data[i];
}
else {
x = ac_math::ac_exp_pwl<res_unsigned_T>(data[i]);
res[i] = x - ac_fixed<1, 1, false>(1.0);
}
for (int ii=0; ii<CONFIG_T::n_in; ii++) {
res[ii] = ac_math::ac_elu_pwl(data[ii], alpha);
}
}

#endif

// *************************************************
// SELU Activation
// *************************************************
Expand Down Expand Up @@ -1022,7 +1033,9 @@ void selu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
template<class data_T, class res_T, typename CONFIG_T>
void selu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
assert("selu not implemented in AC Math");
for (int ii=0; ii<CONFIG_T::n_in; ii++) {
res[ii] = ac_math::ac_selu_pwl(data[ii]);
}
}

#endif
Expand Down
62 changes: 44 additions & 18 deletions hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
#include <ac_math/ac_tanh_pwl.h>
#include <ac_math/ac_sigmoid_pwl.h>
#include <ac_math/ac_pow_pwl.h>
#include <ac_math/ac_elu_pwl.h>
#include <ac_math/ac_selu_pwl.h>
#include <ac_math/ac_softplus_pwl.h>
#include <ac_math/ac_softsign_pwl.h>
#include "nnet_common.h"
#include "nnet_types.h"
#include "nnet_stream.h"
Expand Down Expand Up @@ -631,7 +635,15 @@ void softplus(ac_channel<data_T> &data, ac_channel<res_T> &res) {
template<class data_T, class res_T, typename CONFIG_T>
void softplus(ac_channel<data_T> &data, ac_channel<res_T> &res)
{
assert("softplus stream not implemented for AC Math");
SoftplusActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
data_T in_data = data.read();
res_T out_data;
#pragma hls_unroll
SoftplusPackLoop: for (int j = 0; j < res_T::size; j++) {
ac_math::ac_softplus_pwl(in_data[j],out_data[j]);
}
res.write(out_data);
}
}

#endif
Expand Down Expand Up @@ -684,7 +696,15 @@ void softsign(ac_channel<data_T> &data, ac_channel<res_T> &res) {
template<class data_T, class res_T, typename CONFIG_T>
void softsign(ac_channel<data_T> &data, ac_channel<res_T> &res)
{
assert("softsign stream not implemented for AC Math");
SoftsignActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
data_T in_data = data.read();
res_T out_data;
#pragma hls_unroll
SoftsignPackLoop: for (int j = 0; j < res_T::size; j++) {
ac_math::ac_softsign_pwl(in_data[j],out_data[j]);
}
res.write(out_data);
}
}

#endif
Expand All @@ -694,11 +714,18 @@ assert("softsign stream not implemented for AC Math");
// ELU Activation
// *************************************************

#ifndef USE_AC_MATH

template<class data_T, class res_T, typename CONFIG_T>
void elu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel<res_T> &res) {
// Initialize the lookup table
#ifdef __HLS_SYN__
bool initialized = false;
typename CONFIG_T::table_t elu_table[CONFIG_T::table_size];
#else
static bool initialized = false;
static typename CONFIG_T::table_t elu_table[CONFIG_T::table_size];
#endif

if (!initialized) {
init_elu_table<CONFIG_T, CONFIG_T::table_size>(elu_table);
Expand Down Expand Up @@ -731,33 +758,24 @@ void elu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel
}
}

#else

template<class data_T, class res_T, typename CONFIG_T>
void elu(ac_channel<data_T> &data, ac_channel<res_T> &res)
void elu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel<res_T> &res)
{
//assert("elu stream not implemented for AC Math");
#pragma hls_pipeline_init_interval 1
EluActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {

data_T in_data = data.read();
res_T out_data;
typedef class res_T::value_type out_T;
typedef ac_fixed<out_T::width, out_T::i_width, false, out_T::q_mode, out_T::o_mode> out_unsigned_T;
out_unsigned_T x;
#pragma hls_unroll
EluPackLoop: for (int j = 0; j < res_T::size; j++) {
//int data_round = in_data[j]*CONFIG_T::table_size/8;
if (in_data[j] > 0) {
out_data[j] = in_data[j];
}
else {
x = ac_math::ac_exp_pwl<out_unsigned_T>(in_data[j]);
out_data[j] = x - ac_fixed<1, 1, false>(1.0);
}
ac_math::ac_elu_pwl(in_data[j],out_data[j],alpha);
}
res.write(out_data);
}
}

#endif

// *************************************************
// SELU Activation
// *************************************************
Expand Down Expand Up @@ -809,7 +827,15 @@ void selu(ac_channel<data_T> &data, ac_channel<res_T> &res) {
template<class data_T, class res_T, typename CONFIG_T>
void selu(ac_channel<data_T> &data, ac_channel<res_T> &res)
{
assert("selu stream not implemented for AC Math");
SeluActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
data_T in_data = data.read();
res_T out_data;
#pragma hls_unroll
SeluPackLoop: for (int j = 0; j < res_T::size; j++) {
ac_math::ac_selu_pwl(in_data[j],out_data[j]);
}
res.write(out_data);
}
}

#endif
Expand Down

0 comments on commit b82e83d

Please sign in to comment.