Skip to content

Commit

Permalink
support floating point as a numeric type so that we can write regular…
Browse files Browse the repository at this point in the history
… decimal representation of real numbers
  • Loading branch information
jiahanxie353 committed Jul 3, 2024
1 parent 75c5cb8 commit 29b1732
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 10 deletions.
21 changes: 21 additions & 0 deletions calyx-py/calyx/numeric_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from decimal import Decimal, getcontext
import math
import logging as log
import struct


class InvalidNumericType(Exception):
Expand Down Expand Up @@ -335,3 +336,23 @@ def bitnum_to_fixed(bitnum: Bitnum, int_width: int) -> FixedPoint:
int_width=int_width,
is_signed=bitnum.is_signed,
)


@dataclass
class FloatingPoint(NumericType):
"""Represents a floating point number."""

def __init__(self, value: str, width: int, is_signed: bool):
super().__init__(value, width, is_signed)

if self.bit_string_repr is None and self.hex_string_repr is None:
# The decimal representation was passed in.
packed = struct.pack('!f', float(self.string_repr))
unpacked = struct.unpack('>I', packed)[0]
self.bit_string_repr = f'{unpacked:0{self.width}b}'
self.uint_repr = int(self.bit_string_repr, 2)
self.hex_string_repr = np.base_repr(self.uint_repr, 16)

def to_dec(self, round_place: int):
float_value = struct.unpack('!f', int(self.bit_string_repr, 2).to_bytes(4, byteorder='big'))[0]
return round(float_value, round_place)
28 changes: 18 additions & 10 deletions fud/fud/stages/verilator/json_to_dat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import simplejson as sjson
import numpy as np
from calyx.numeric_types import FixedPoint, Bitnum, InvalidNumericType
from calyx.numeric_types import FixedPoint, Bitnum, FloatingPoint, InvalidNumericType
from pathlib import Path
from fud.errors import Malformed
import logging as log
Expand All @@ -14,7 +14,7 @@ def float_to_fixed(value: float, N: int) -> float:
return round(value * w) / float(w)


def parse_dat(path, args):
def parse_dat(path, is_bn, args):
"""Parses a number with the given numeric type
arguments from the array at the given `path`.
"""
Expand All @@ -34,12 +34,15 @@ def parse(hex_value: str):
hex_value = f"0x{hex_value}"
if "int_width" in args:
return FixedPoint(hex_value, **args).str_value()
else:
elif is_bn:
bn = Bitnum(hex_value, **args)
if bn.is_undef:
return bn.str_value()
else:
return int(bn.str_value())
else:
fp = FloatingPoint(hex_value, **args)
return fp.to_dec(round_place=2)

with path.open("r") as f:
lines = []
Expand Down Expand Up @@ -90,11 +93,14 @@ def provided(x, y):
)


def convert(x, round: bool, is_signed: bool, width: int, int_width=None):
def convert(x, round: bool, is_signed: bool, width: int, is_bn: bool, int_width=None):
with_prefix = False
# If `int_width` is not defined, then this is a `Bitnum`
if int_width is None:
return Bitnum(x, width, is_signed).hex_string(with_prefix)
if is_bn:
return Bitnum(x, width, is_signed).hex_string(with_prefix)
else:
return FloatingPoint(x, width, is_signed).hex_string(with_prefix)

try:
return FixedPoint(x, width, int_width, is_signed).hex_string(with_prefix)
Expand Down Expand Up @@ -133,14 +139,14 @@ def convert2dat(output_dir, data, extension, round: bool):
numeric_type = format["numeric_type"]
is_signed = format["is_signed"]

if numeric_type not in {"bitnum", "fixed_point"}:
raise InvalidNumericType('Fud only supports "fixed_point" and "bitnum".')
if numeric_type not in {"bitnum", "fixed_point", "floating_point"}:
raise InvalidNumericType('Fud only supports "fixed_point", "bitnum", and "floating_point".')

is_fp = numeric_type == "fixed_point"
if is_fp:
width, int_width = parse_fp_widths(format)
else:
# `Bitnum`s only have a bit width.
# `Bitnum`s and `FloatingPoint`s only have a bit width
width = format["width"]
int_width = None

Expand All @@ -154,7 +160,8 @@ def convert2dat(output_dir, data, extension, round: bool):

with path.open("w") as f:
for v in arr.flatten():
f.write(convert(v, round, is_signed, width, int_width) + "\n")
is_bn=numeric_type == "bitnum"
f.write(convert(v, round, is_signed, width, is_bn, int_width) + "\n")

shape[k]["shape"] = list(arr.shape)
shape[k]["numeric_type"] = numeric_type
Expand Down Expand Up @@ -185,8 +192,9 @@ def convert2json(input_dir, extension):
# for building the FixedPoint or Bitnum classes.
args = form.copy()
del args["shape"]
is_bn = args["numeric_type"] == "bitnum"
del args["numeric_type"]
arr = parse_dat(path, args)
arr = parse_dat(path, is_bn, args)
if form["shape"] == [0]:
raise Malformed(
"Data format shape",
Expand Down
22 changes: 22 additions & 0 deletions tests/correctness/float/float-const.data
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"mem_read": {
"data": [
4.2
],
"format": {
"is_signed": true,
"numeric_type": "floating_point",
"width": 32
}
},
"mem_write": {
"data": [
0.0
],
"format": {
"is_signed": true,
"numeric_type": "floating_point",
"width": 32
}
}
}
31 changes: 31 additions & 0 deletions tests/correctness/float/float-const.futil
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import "primitives/compile.futil";
import "primitives/memories/comb.futil";

component main<"toplevel"=1,>(@clk clk: 1, @reset reset: 1, @go go: 1) -> (@done done: 1) {
cells {
reg0 = std_reg(32);
@external mem_read = comb_mem_d1(32, 1, 1);
@external mem_write = comb_mem_d1(32, 1, 1);
}
wires {
group read {
mem_read.addr0 = 1'b0;
reg0.in = mem_read.read_data;
reg0.write_en = 1'b1;
read[done] = reg0.done;
}

group write {
mem_write.addr0 = 1'b0;
mem_write.write_en = 1'b1;
mem_write.write_data = reg0.out;
write[done] = mem_write.done;
}
}
control {
seq {
read;
write;
}
}
}

0 comments on commit 29b1732

Please sign in to comment.