Skip to content

Commit

Permalink
Update HTP models
Browse files Browse the repository at this point in the history
This updates the HTP models to the version that were evaluated
  • Loading branch information
etrommer committed Dec 6, 2023
1 parent f1bd009 commit b655c53
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/torchapprox/operators/htp_models/htp_models_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@ def accurate_reference(base_func, op1, op2, kwargs):

def htp_mitchell_trunc(base_func, op1, op2, kwargs):
EPS = torch.tensor([1e-6]).cuda()
k = 3
k = 5

def transform_operand(op):
sgn = op < 0
op = torch.abs(op)
rem = torch.floor(torch.log2(torch.maximum(op, EPS)))
rem = 2 ** (rem - k + 2)
op -= torch.where(rem > 8, op % rem, op)
op += torch.where(rem > 8, rem / 2, 0)
rem = 2 ** (rem - k + 1)
op -= torch.where(op > 2 ** (k + 1), op % rem, 0)
op = torch.where(sgn, -op, op)
return op

Expand All @@ -31,13 +30,13 @@ def transform_operand(op):

def htp_drum(base_func, op1, op2, kwargs):
EPS = torch.tensor([1e-6]).cuda()
k = 3
k = 5

def transform_operand(op):
sgn = op < 0
op = torch.abs(op)
rem = torch.floor(torch.log2(torch.maximum(op, EPS)))
rem = torch.maximum(2 ** (rem - k + 1), torch.tensor([1]).cuda())
rem = 2 ** (rem - k + 2)
op -= torch.where(op > 2 ** (k + 1), op % rem, 0)
op += torch.where(op > 2 ** (k + 1), rem / 2, 0)
op = torch.where(sgn, -op, op)
Expand Down

0 comments on commit b655c53

Please sign in to comment.