Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions ggml/src/ggml-hexagon/htp/hvx-exp.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@
#include "hvx-utils.h"
#include "ops-utils.h"

static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) {
static const float kInf = INFINITY;
static const float kMaxExp = 88.02f; // log(INF)

const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);

HVX_Vector out = hvx_vec_exp_fp32(in_vec);

Expand All @@ -47,6 +42,12 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int

HVX_Vector vec_out = Q6_V_vzero();

static const float kInf = INFINITY;
static const float kMaxExp = 88.02f; // log(INF)

const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);

if (0 == unaligned_loop) {
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
Expand All @@ -55,9 +56,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in);
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
} else {
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++);
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf);
}
}
} else {
Expand All @@ -67,9 +68,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int

if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
} else {
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf);
}
}
}
Expand All @@ -83,9 +84,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);

vec_out = hvx_vec_exp_fp32_guard(neg_vec_in);
vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
} else {
vec_out = hvx_vec_exp_fp32_guard(in);
vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf);
}

hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
Expand Down
18 changes: 15 additions & 3 deletions ggml/src/ggml-hexagon/htp/hvx-inverse.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
#include "hvx-utils.h"
#include "ops-utils.h"

static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
HVX_Vector out = hvx_vec_inverse_fp32(v_sf);

HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);

return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
}

void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
Expand All @@ -32,19 +41,22 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
}

static const uint32_t kNanInfMask = 0x7f800000;
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask);

if (0 == unaligned_loop) {
HVX_Vector * p_vec_in = (HVX_Vector *) src;
HVX_Vector * p_vec_out = (HVX_Vector *) dst;

#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++);
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask);
}
} else {
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
}
}

Expand All @@ -53,7 +65,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
float * dstf = (float *) dst + num_elems_whole;

HVX_Vector in = *(HVX_UVector *) srcf;
HVX_Vector out = hvx_vec_inverse_fp32_guard(in);
HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);

hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
}
Expand Down
41 changes: 16 additions & 25 deletions ggml/src/ggml-hexagon/htp/hvx-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,24 +726,6 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
return Q6_Vsf_equals_Vqf32(r_qf);
}

static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) {
static const float kInf = INFINITY;
static const uint32_t kNanMask = 0x7fffffff;
static const uint32_t kNanMin = 0x7f800000;

const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf);

HVX_Vector out = hvx_vec_inverse_fp32(v_sf);

const HVX_Vector nan_mask = Q6_V_vsplat_R(kNanMask);
const HVX_Vector nan_min = Q6_V_vsplat_R(kNanMin);
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_mask);
const HVX_VectorPred pred = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out);

return Q6_V_vmux_QVV(pred, out, Q6_V_vzero());
}

#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
Expand Down Expand Up @@ -958,14 +940,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
return Q6_Vsf_equals_Vqf32(temp);
}

static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) {
static const float kMaxExp = -88.02f; // log(INF)

const HVX_Vector max_exp = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp));
const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp);
static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
HVX_Vector one,
HVX_Vector max_exp,
HVX_Vector min_exp) {
const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);

HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v);
return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero());
out = Q6_V_vmux_QVV(pred_max, out, one);
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
}

static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
Expand All @@ -977,9 +961,16 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;

static const float kMinExp = -87.f; // 0
static const float kMaxExp = 87.f; // 1

const HVX_Vector one = hvx_vec_splat_fp32(1.f);
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);

#pragma unroll(4)
for (int i = 0; i < step_of_1; i++) {
v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]);
v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp);
}
}

Expand Down
Loading