diff --git a/calyx-frontend/src/parser.rs b/calyx-frontend/src/parser.rs index 4de02c1148..2a48c2540e 100644 --- a/calyx-frontend/src/parser.rs +++ b/calyx-frontend/src/parser.rs @@ -294,6 +294,29 @@ impl CalyxParser { u64::from_str_radix(input.as_str(), 2) .map_err(|_| input.error("Expected binary number")) } + fn float(input: Node) -> ParseResult { + let float_str = input.as_str(); + let float_val: f64 = float_str + .parse() + .map_err(|_| input.error("Expected valid floating-point number"))?; + Ok(float_val) + } + + fn ieee754_const(input: Node) -> ParseResult { + println!("parsing iee754"); + let span = Self::get_span(&input); + let val = match_nodes!( + input.clone().into_children(); + [float(val)] => val + ); + let bit_pattern = val.to_bits(); + Ok(BitNum { + width: 64, + num_type: NumType::Hex, + val: bit_pattern, + span, + }) + } fn num_lit(input: Node) -> ParseResult { let span = Self::get_span(&input); @@ -323,6 +346,7 @@ impl CalyxParser { val, span }, + [ieee754_const(val)] => val, ); diff --git a/calyx-frontend/src/syntax.pest b/calyx-frontend/src/syntax.pest index 90b85263f6..6dc5b80f93 100644 --- a/calyx-frontend/src/syntax.pest +++ b/calyx-frontend/src/syntax.pest @@ -16,17 +16,22 @@ binary = @{ ASCII_HEX_DIGIT+ } decimal = @{ ASCII_HEX_DIGIT+ } octal = @{ ASCII_HEX_DIGIT+ } hex = @{ ASCII_HEX_DIGIT+ } +float = @{ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ } + +ieee754_const = { "ieee754_const(" ~ float ~ ")" } // `$` creates a compound rule which ignores whitespace while allowing for // inner rules (`@` makes inner rules silent). // See: https://pest.rs/book/print.html#atomic num_lit = ${ - bitwidth + (bitwidth ~ "'" ~ ( "d" ~ decimal | "b" ~ binary | "x" ~ hex - | "o" ~ octal ) + | "o" ~ octal) + ) + | ieee754_const } char = { !"\"" ~ ANY } diff --git a/calyx-py/calyx/numeric_types.py b/calyx-py/calyx/numeric_types.py index da5c901ed0..56cbbd9f53 100644 --- a/calyx-py/calyx/numeric_types.py +++ b/calyx-py/calyx/numeric_types.py @@ -7,6 +7,7 @@ from decimal import Decimal, getcontext import math import logging as log +import struct class InvalidNumericType(Exception): @@ -335,3 +336,23 @@ def bitnum_to_fixed(bitnum: Bitnum, int_width: int) -> FixedPoint: int_width=int_width, is_signed=bitnum.is_signed, ) + + +@dataclass +class FloatingPoint(NumericType): + """Represents a floating point number.""" + + def __init__(self, value: str, width: int, is_signed: bool): + super().__init__(value, width, is_signed) + + if self.bit_string_repr is None and self.hex_string_repr is None: + # The decimal representation was passed in. + packed = struct.pack('!f', float(self.string_repr)) + unpacked = struct.unpack('>I', packed)[0] + self.bit_string_repr = f'{unpacked:0{self.width}b}' + self.uint_repr = int(self.bit_string_repr, 2) + self.hex_string_repr = np.base_repr(self.uint_repr, 16) + + def to_dec(self, round_place: int): + float_value = struct.unpack('!f', int(self.bit_string_repr, 2).to_bytes(4, byteorder='big'))[0] + return round(float_value, round_place) diff --git a/fud/fud/stages/verilator/json_to_dat.py b/fud/fud/stages/verilator/json_to_dat.py index a988f62ff3..a7f1c230b9 100644 --- a/fud/fud/stages/verilator/json_to_dat.py +++ b/fud/fud/stages/verilator/json_to_dat.py @@ -1,6 +1,6 @@ import simplejson as sjson import numpy as np -from calyx.numeric_types import FixedPoint, Bitnum, InvalidNumericType +from calyx.numeric_types import FixedPoint, Bitnum, FloatingPoint, InvalidNumericType from pathlib import Path from fud.errors import Malformed import logging as log @@ -14,7 +14,7 @@ def float_to_fixed(value: float, N: int) -> float: return round(value * w) / float(w) -def parse_dat(path, args): +def parse_dat(path, is_bn, args): """Parses a number with the given numeric type arguments from the array at the given `path`. """ @@ -34,12 +34,15 @@ def parse(hex_value: str): hex_value = f"0x{hex_value}" if "int_width" in args: return FixedPoint(hex_value, **args).str_value() - else: + elif is_bn: bn = Bitnum(hex_value, **args) if bn.is_undef: return bn.str_value() else: return int(bn.str_value()) + else: + fp = FloatingPoint(hex_value, **args) + return fp.to_dec(round_place=2) with path.open("r") as f: lines = [] @@ -90,11 +93,14 @@ def provided(x, y): ) -def convert(x, round: bool, is_signed: bool, width: int, int_width=None): +def convert(x, round: bool, is_signed: bool, width: int, is_bn: bool, int_width=None): with_prefix = False # If `int_width` is not defined, then this is a `Bitnum` if int_width is None: - return Bitnum(x, width, is_signed).hex_string(with_prefix) + if is_bn: + return Bitnum(x, width, is_signed).hex_string(with_prefix) + else: + return FloatingPoint(x, width, is_signed).hex_string(with_prefix) try: return FixedPoint(x, width, int_width, is_signed).hex_string(with_prefix) @@ -133,14 +139,14 @@ def convert2dat(output_dir, data, extension, round: bool): numeric_type = format["numeric_type"] is_signed = format["is_signed"] - if numeric_type not in {"bitnum", "fixed_point"}: - raise InvalidNumericType('Fud only supports "fixed_point" and "bitnum".') + if numeric_type not in {"bitnum", "fixed_point", "floating_point"}: + raise InvalidNumericType('Fud only supports "fixed_point", "bitnum", and "floating_point".') is_fp = numeric_type == "fixed_point" if is_fp: width, int_width = parse_fp_widths(format) else: - # `Bitnum`s only have a bit width. + # `Bitnum`s and `FloatingPoint`s only have a bit width width = format["width"] int_width = None @@ -154,7 +160,8 @@ def convert2dat(output_dir, data, extension, round: bool): with path.open("w") as f: for v in arr.flatten(): - f.write(convert(v, round, is_signed, width, int_width) + "\n") + is_bn=numeric_type == "bitnum" + f.write(convert(v, round, is_signed, width, is_bn, int_width) + "\n") shape[k]["shape"] = list(arr.shape) shape[k]["numeric_type"] = numeric_type @@ -185,8 +192,9 @@ def convert2json(input_dir, extension): # for building the FixedPoint or Bitnum classes. args = form.copy() del args["shape"] + is_bn = args["numeric_type"] == "bitnum" del args["numeric_type"] - arr = parse_dat(path, args) + arr = parse_dat(path, is_bn, args) if form["shape"] == [0]: raise Malformed( "Data format shape", diff --git a/tests/correctness/float/float-const.data b/tests/correctness/float/float-const.data new file mode 100644 index 0000000000..5a00caae4f --- /dev/null +++ b/tests/correctness/float/float-const.data @@ -0,0 +1,22 @@ +{ + "mem_read": { + "data": [ + 4.2 + ], + "format": { + "is_signed": true, + "numeric_type": "floating_point", + "width": 32 + } + }, + "mem_write": { + "data": [ + 0.0 + ], + "format": { + "is_signed": true, + "numeric_type": "floating_point", + "width": 32 + } + } +} \ No newline at end of file diff --git a/tests/correctness/float/float-const.expect b/tests/correctness/float/float-const.expect new file mode 100644 index 0000000000..38ca613fe6 --- /dev/null +++ b/tests/correctness/float/float-const.expect @@ -0,0 +1,22 @@ +{ + "mem_read": { + "data": [ + 4.2 + ], + "format": { + "is_signed": true, + "numeric_type": "floating_point", + "width": 32 + } + }, + "mem_write": { + "data": [ + 4.2 + ], + "format": { + "is_signed": true, + "numeric_type": "floating_point", + "width": 32 + } + } +} \ No newline at end of file diff --git a/tests/correctness/float/float-const.futil b/tests/correctness/float/float-const.futil new file mode 100644 index 0000000000..969494e51c --- /dev/null +++ b/tests/correctness/float/float-const.futil @@ -0,0 +1,31 @@ +import "primitives/compile.futil"; +import "primitives/memories/comb.futil"; + +component main<"toplevel"=1,>(@clk clk: 1, @reset reset: 1, @go go: 1) -> (@done done: 1) { + cells { + reg0 = std_reg(32); + @external mem_read = comb_mem_d1(32, 1, 1); + @external mem_write = comb_mem_d1(32, 1, 1); + } + wires { + group read { + mem_read.addr0 = 1'b0; + reg0.in = ieee754_const(1.00); + reg0.write_en = 1'b1; + read[done] = reg0.done; + } + + group write { + mem_write.addr0 = 1'b0; + mem_write.write_en = 1'b1; + mem_write.write_data = reg0.out; + write[done] = mem_write.done; + } + } + control { + seq { + read; + write; + } + } +} \ No newline at end of file