Skip to content

Commit

Permalink
Replace VectorXx.Exp's edge case fallback with scalar processing (dot…
Browse files Browse the repository at this point in the history
…net#107886)

* Replace VectorXx.Exp's edge case fallback with scalar processing

The better, vectorized fix is more complex and can be done for .NET 10.

* Revert addition to Helpers.IsEqualWithTolerance
  • Loading branch information
stephentoub authored Sep 18, 2024
1 parent 128837d commit 24e7d1b
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,6 @@ public static Vector512<T> Invoke(Vector512<T> x)
private const ulong V_ARG_MAX = 0x40862000_00000000;
private const ulong V_DP64_BIAS = 1023;

private const double V_EXPF_MIN = -709.782712893384;
private const double V_EXPF_MAX = +709.782712893384;

private const double V_EXPF_HUGE = 6755399441055744;
private const double V_TBL_LN2 = 1.4426950408889634;

Expand All @@ -183,155 +180,145 @@ public static Vector512<T> Invoke(Vector512<T> x)

public static Vector128<double> Invoke(Vector128<double> x)
{
// x * (64.0 / ln(2))
Vector128<double> z = x * Vector128.Create(V_TBL_LN2);

Vector128<double> dn = z + Vector128.Create(V_EXPF_HUGE);
// Check if -709 < vx < 709
if (Vector128.LessThanOrEqualAll(Vector128.Abs(x).AsUInt64(), Vector128.Create(V_ARG_MAX)))
{
// x * (64.0 / ln(2))
Vector128<double> z = x * Vector128.Create(V_TBL_LN2);

// n = (int)z
Vector128<ulong> n = dn.AsUInt64();
Vector128<double> dn = z + Vector128.Create(V_EXPF_HUGE);

// dn = (double)n
dn -= Vector128.Create(V_EXPF_HUGE);
// n = (int)z
Vector128<ulong> n = dn.AsUInt64();

// r = x - (dn * (ln(2) / 64))
// where ln(2) / 64 is split into Head and Tail values
Vector128<double> r = x - (dn * Vector128.Create(V_LN2_HEAD)) - (dn * Vector128.Create(V_LN2_TAIL));
// dn = (double)n
dn -= Vector128.Create(V_EXPF_HUGE);

Vector128<double> r2 = r * r;
Vector128<double> r4 = r2 * r2;
Vector128<double> r8 = r4 * r4;
// r = x - (dn * (ln(2) / 64))
// where ln(2) / 64 is split into Head and Tail values
Vector128<double> r = x - (dn * Vector128.Create(V_LN2_HEAD)) - (dn * Vector128.Create(V_LN2_TAIL));

// Compute polynomial
Vector128<double> poly = ((Vector128.Create(C12) * r + Vector128.Create(C11)) * r2 +
Vector128.Create(C10) * r + Vector128.Create(C9)) * r8 +
((Vector128.Create(C8) * r + Vector128.Create(C7)) * r2 +
(Vector128.Create(C6) * r + Vector128.Create(C5))) * r4 +
((Vector128.Create(C4) * r + Vector128.Create(C3)) * r2 + (r + Vector128<double>.One));
Vector128<double> r2 = r * r;
Vector128<double> r4 = r2 * r2;
Vector128<double> r8 = r4 * r4;

// m = (n - j) / 64
// result = polynomial * 2^m
Vector128<double> ret = poly * ((n + Vector128.Create(V_DP64_BIAS)) << 52).AsDouble();
// Compute polynomial
Vector128<double> poly = ((Vector128.Create(C12) * r + Vector128.Create(C11)) * r2 +
Vector128.Create(C10) * r + Vector128.Create(C9)) * r8 +
((Vector128.Create(C8) * r + Vector128.Create(C7)) * r2 +
(Vector128.Create(C6) * r + Vector128.Create(C5))) * r4 +
((Vector128.Create(C4) * r + Vector128.Create(C3)) * r2 + (r + Vector128<double>.One));

// Check if -709 < vx < 709
if (Vector128.GreaterThanAny(Vector128.Abs(x).AsUInt64(), Vector128.Create(V_ARG_MAX)))
// m = (n - j) / 64
// result = polynomial * 2^m
return poly * ((n + Vector128.Create(V_DP64_BIAS)) << 52).AsDouble();
}
else
{
// (x > V_EXPF_MAX) ? double.PositiveInfinity : x
Vector128<double> infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX));

ret = Vector128.ConditionalSelect(
infinityMask,
Vector128.Create(double.PositiveInfinity),
ret
);
return ScalarFallback(x);

// (x < V_EXPF_MIN) ? 0 : x
ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN)));
static Vector128<double> ScalarFallback(Vector128<double> x) =>
Vector128.Create(Math.Exp(x.GetElement(0)),
Math.Exp(x.GetElement(1)));
}

return ret;
}

public static Vector256<double> Invoke(Vector256<double> x)
{
// x * (64.0 / ln(2))
Vector256<double> z = x * Vector256.Create(V_TBL_LN2);

Vector256<double> dn = z + Vector256.Create(V_EXPF_HUGE);
// Check if -709 < vx < 709
if (Vector256.LessThanOrEqualAll(Vector256.Abs(x).AsUInt64(), Vector256.Create(V_ARG_MAX)))
{
// x * (64.0 / ln(2))
Vector256<double> z = x * Vector256.Create(V_TBL_LN2);

// n = (int)z
Vector256<ulong> n = dn.AsUInt64();
Vector256<double> dn = z + Vector256.Create(V_EXPF_HUGE);

// dn = (double)n
dn -= Vector256.Create(V_EXPF_HUGE);
// n = (int)z
Vector256<ulong> n = dn.AsUInt64();

// r = x - (dn * (ln(2) / 64))
// where ln(2) / 64 is split into Head and Tail values
Vector256<double> r = x - (dn * Vector256.Create(V_LN2_HEAD)) - (dn * Vector256.Create(V_LN2_TAIL));
// dn = (double)n
dn -= Vector256.Create(V_EXPF_HUGE);

Vector256<double> r2 = r * r;
Vector256<double> r4 = r2 * r2;
Vector256<double> r8 = r4 * r4;
// r = x - (dn * (ln(2) / 64))
// where ln(2) / 64 is split into Head and Tail values
Vector256<double> r = x - (dn * Vector256.Create(V_LN2_HEAD)) - (dn * Vector256.Create(V_LN2_TAIL));

// Compute polynomial
Vector256<double> poly = ((Vector256.Create(C12) * r + Vector256.Create(C11)) * r2 +
Vector256.Create(C10) * r + Vector256.Create(C9)) * r8 +
((Vector256.Create(C8) * r + Vector256.Create(C7)) * r2 +
(Vector256.Create(C6) * r + Vector256.Create(C5))) * r4 +
((Vector256.Create(C4) * r + Vector256.Create(C3)) * r2 + (r + Vector256<double>.One));
Vector256<double> r2 = r * r;
Vector256<double> r4 = r2 * r2;
Vector256<double> r8 = r4 * r4;

// m = (n - j) / 64
// result = polynomial * 2^m
Vector256<double> ret = poly * ((n + Vector256.Create(V_DP64_BIAS)) << 52).AsDouble();
// Compute polynomial
Vector256<double> poly = ((Vector256.Create(C12) * r + Vector256.Create(C11)) * r2 +
Vector256.Create(C10) * r + Vector256.Create(C9)) * r8 +
((Vector256.Create(C8) * r + Vector256.Create(C7)) * r2 +
(Vector256.Create(C6) * r + Vector256.Create(C5))) * r4 +
((Vector256.Create(C4) * r + Vector256.Create(C3)) * r2 + (r + Vector256<double>.One));

// Check if -709 < vx < 709
if (Vector256.GreaterThanAny(Vector256.Abs(x).AsUInt64(), Vector256.Create(V_ARG_MAX)))
// m = (n - j) / 64
// result = polynomial * 2^m
return poly * ((n + Vector256.Create(V_DP64_BIAS)) << 52).AsDouble();
}
else
{
// (x > V_EXPF_MAX) ? double.PositiveInfinity : x
Vector256<double> infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX));
return ScalarFallback(x);

ret = Vector256.ConditionalSelect(
infinityMask,
Vector256.Create(double.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : x
ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN)));
static Vector256<double> ScalarFallback(Vector256<double> x) =>
Vector256.Create(Math.Exp(x.GetElement(0)),
Math.Exp(x.GetElement(1)),
Math.Exp(x.GetElement(2)),
Math.Exp(x.GetElement(3)));
}

return ret;
}

public static Vector512<double> Invoke(Vector512<double> x)
{
// x * (64.0 / ln(2))
Vector512<double> z = x * Vector512.Create(V_TBL_LN2);

Vector512<double> dn = z + Vector512.Create(V_EXPF_HUGE);
// Check if -709 < vx < 709
if (Vector512.LessThanOrEqualAll(Vector512.Abs(x).AsUInt64(), Vector512.Create(V_ARG_MAX)))
{
// x * (64.0 / ln(2))
Vector512<double> z = x * Vector512.Create(V_TBL_LN2);

// n = (int)z
Vector512<ulong> n = dn.AsUInt64();
Vector512<double> dn = z + Vector512.Create(V_EXPF_HUGE);

// dn = (double)n
dn -= Vector512.Create(V_EXPF_HUGE);
// n = (int)z
Vector512<ulong> n = dn.AsUInt64();

// r = x - (dn * (ln(2) / 64))
// where ln(2) / 64 is split into Head and Tail values
Vector512<double> r = x - (dn * Vector512.Create(V_LN2_HEAD)) - (dn * Vector512.Create(V_LN2_TAIL));
// dn = (double)n
dn -= Vector512.Create(V_EXPF_HUGE);

Vector512<double> r2 = r * r;
Vector512<double> r4 = r2 * r2;
Vector512<double> r8 = r4 * r4;
// r = x - (dn * (ln(2) / 64))
// where ln(2) / 64 is split into Head and Tail values
Vector512<double> r = x - (dn * Vector512.Create(V_LN2_HEAD)) - (dn * Vector512.Create(V_LN2_TAIL));

// Compute polynomial
Vector512<double> poly = ((Vector512.Create(C12) * r + Vector512.Create(C11)) * r2 +
Vector512.Create(C10) * r + Vector512.Create(C9)) * r8 +
((Vector512.Create(C8) * r + Vector512.Create(C7)) * r2 +
(Vector512.Create(C6) * r + Vector512.Create(C5))) * r4 +
((Vector512.Create(C4) * r + Vector512.Create(C3)) * r2 + (r + Vector512<double>.One));
Vector512<double> r2 = r * r;
Vector512<double> r4 = r2 * r2;
Vector512<double> r8 = r4 * r4;

// m = (n - j) / 64
// result = polynomial * 2^m
Vector512<double> ret = poly * ((n + Vector512.Create(V_DP64_BIAS)) << 52).AsDouble();
// Compute polynomial
Vector512<double> poly = ((Vector512.Create(C12) * r + Vector512.Create(C11)) * r2 +
Vector512.Create(C10) * r + Vector512.Create(C9)) * r8 +
((Vector512.Create(C8) * r + Vector512.Create(C7)) * r2 +
(Vector512.Create(C6) * r + Vector512.Create(C5))) * r4 +
((Vector512.Create(C4) * r + Vector512.Create(C3)) * r2 + (r + Vector512<double>.One));

// Check if -709 < vx < 709
if (Vector512.GreaterThanAny(Vector512.Abs(x).AsUInt64(), Vector512.Create(V_ARG_MAX)))
// m = (n - j) / 64
// result = polynomial * 2^m
return poly * ((n + Vector512.Create(V_DP64_BIAS)) << 52).AsDouble();
}
else
{
// (x > V_EXPF_MAX) ? double.PositiveInfinity : x
Vector512<double> infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX));

ret = Vector512.ConditionalSelect(
infinityMask,
Vector512.Create(double.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : x
ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN)));
return ScalarFallback(x);

static Vector512<double> ScalarFallback(Vector512<double> x) =>
Vector512.Create(Math.Exp(x.GetElement(0)),
Math.Exp(x.GetElement(1)),
Math.Exp(x.GetElement(2)),
Math.Exp(x.GetElement(3)),
Math.Exp(x.GetElement(4)),
Math.Exp(x.GetElement(5)),
Math.Exp(x.GetElement(6)),
Math.Exp(x.GetElement(7)));
}

return ret;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,17 @@ protected T NextRandom(T avoid)
/// the value is stored into a random position in <paramref name="x"/>, and the original
/// value is subsequently restored.
/// </summary>
protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x)
protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x) =>
RunForEachSpecialValue(action, x, GetSpecialValues());

/// <summary>
/// Runs the specified action for each special value. Before the action is invoked,
/// the value is stored into a random position in <paramref name="x"/>, and the original
/// value is subsequently restored.
/// </summary>
protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x, IEnumerable<T> specialValues)
{
Assert.All(GetSpecialValues(), value =>
Assert.All(specialValues, value =>
{
int pos = Random.Next(x.Length);
T orig = x[pos];
Expand Down Expand Up @@ -1021,14 +1029,25 @@ public void Exp_SpecialValues()
using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
using BoundedMemory<T> destination = CreateTensor(tensorLength);
T[] additionalSpecialValues =
[
typeof(T) == typeof(float) ? (T)(object)-709.7f :
typeof(T) == typeof(double) ? (T)(object)-709.7 :
default,
typeof(T) == typeof(float) ? (T)(object)709.7f :
typeof(T) == typeof(double) ? (T)(object)709.7 :
default,
];
RunForEachSpecialValue(() =>
{
Exp(x, destination);
for (int i = 0; i < tensorLength; i++)
{
AssertEqualTolerance(Exp(x[i]), destination[i]);
}
}, x);
}, x, GetSpecialValues().Concat(additionalSpecialValues));
});
}

Expand Down
Loading

0 comments on commit 24e7d1b

Please sign in to comment.