Skip to content

Commit

Permalink
Refactor DRUM HTP model
Browse files Browse the repository at this point in the history
  • Loading branch information
etrommer committed Nov 15, 2023
1 parent db33837 commit a5728e0
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/torchapprox/operators/htp_models/htp_models_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,18 @@ def htp_mitchell_trunc(base_func, op1, op2, kwargs):
def htp_drum(base_func, op1, op2, kwargs):
EPS = torch.tensor([1e-6])
k = 4
a = torch.floor(torch.log2(torch.maximum(torch.abs(op1), EPS)))
b = torch.floor(torch.log2(torch.maximum(torch.abs(op2), EPS)))
a = torch.maximum(2 ** (a - k + 2), torch.tensor([1]))
b = torch.maximum(2 ** (b - k + 2), torch.tensor([1]))

op1 -= torch.fmod(op1, a)
op2 -= torch.fmod(op2, b)

a = torch.where(op1 > 0, a, -a)
b = torch.where(op2 > 0, b, -b)

# Debiasing
op1 += a / 2
op2 += b / 2
def transform_operand(op):
sgn = op < 0
op = torch.abs(op)
rem = torch.floor(torch.log2(torch.maximum(op, EPS)))
rem = torch.maximum(2 ** (rem - k + 1), 0)
op -= op % rem
op += rem / 2
op = torch.where(sgn, -op, op)

op1 = transform_operand(op1)
op2 = transform_operand(op2)

res = base_func(op1, op2, **kwargs)
return res
Expand Down

0 comments on commit a5728e0

Please sign in to comment.