Skip to content

Commit 82f3a61

Browse files
author
The ml_dtypes Authors
committed
Merge pull request #171 from apivovarov:e3m4
PiperOrigin-RevId: 668214541
2 parents f053b3c + 4a03c71 commit 82f3a61

File tree

9 files changed

+377
-17
lines changed

9 files changed

+377
-17
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26-
* Added new 8-bit float type following IEEE 754 convention:
27-
`ml_dtypes.float8_e4m3`.
26+
* Added new 8-bit float types following IEEE 754 convention:
27+
`ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`.
2828
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.
2929

3030
## [0.4.0] - 2024-04-01

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
1111
- `float8_*`: several experimental 8-bit floating point representations
1212
including:
13+
* `float8_e3m4`
1314
* `float8_e4m3`
1415
* `float8_e4m3b11fnuz`
1516
* `float8_e4m3fn`
@@ -65,6 +66,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.
6566

6667
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
6768

69+
### `float8_e3m4`
70+
71+
Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.
72+
6873
### `float8_e4m3`
6974

7075
Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

ml_dtypes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"__version__",
1818
"bfloat16",
1919
"finfo",
20+
"float8_e3m4",
2021
"float8_e4m3",
2122
"float8_e4m3b11fnuz",
2223
"float8_e4m3fn",
@@ -35,6 +36,7 @@
3536
from ml_dtypes._finfo import finfo
3637
from ml_dtypes._iinfo import iinfo
3738
from ml_dtypes._ml_dtypes_ext import bfloat16
39+
from ml_dtypes._ml_dtypes_ext import float8_e3m4
3840
from ml_dtypes._ml_dtypes_ext import float8_e4m3
3941
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
4042
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
@@ -48,6 +50,7 @@
4850
import numpy as np
4951

5052
bfloat16: Type[np.generic]
53+
float8_e3m4: Type[np.generic]
5154
float8_e4m3: Type[np.generic]
5255
float8_e4m3b11fnuz: Type[np.generic]
5356
float8_e4m3fn: Type[np.generic]

ml_dtypes/_finfo.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Dict
1818

1919
from ml_dtypes._ml_dtypes_ext import bfloat16
20+
from ml_dtypes._ml_dtypes_ext import float8_e3m4
2021
from ml_dtypes._ml_dtypes_ext import float8_e4m3
2122
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
2223
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
@@ -26,6 +27,7 @@
2627
import numpy as np
2728

2829
_bfloat16_dtype = np.dtype(bfloat16)
30+
_float8_e3m4_dtype = np.dtype(float8_e3m4)
2931
_float8_e4m3_dtype = np.dtype(float8_e4m3)
3032
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
3133
_float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)
@@ -43,12 +45,21 @@ def __init__(self):
4345
self.smallest_subnormal = bfloat16(smallest_subnormal)
4446

4547

48+
class _Float8E3m4MachArLike:
49+
50+
def __init__(self):
51+
smallest_normal = float.fromhex("0x1p-2")
52+
self.smallest_normal = float8_e3m4(smallest_normal)
53+
smallest_subnormal = float.fromhex("0x0.1p-2")
54+
self.smallest_subnormal = float8_e3m4(smallest_subnormal)
55+
56+
4657
class _Float8E4m3MachArLike:
4758

4859
def __init__(self):
4960
smallest_normal = float.fromhex("0x1p-6")
5061
self.smallest_normal = float8_e4m3(smallest_normal)
51-
smallest_subnormal = float.fromhex("0x1p-9")
62+
smallest_subnormal = float.fromhex("0x0.2p-6")
5263
self.smallest_subnormal = float8_e4m3(smallest_subnormal)
5364

5465

@@ -146,6 +157,51 @@ def float_to_str(f):
146157
# pylint: enable=protected-access
147158
return obj
148159

160+
@staticmethod
161+
def _float8_e3m4_finfo():
162+
def float_to_str(f):
163+
return "%6.2e" % float(f)
164+
165+
tiny = float.fromhex("0x1p-2") # 1/4 min normal
166+
resolution = 0.1
167+
eps = float.fromhex("0x1p-4") # 1/16
168+
epsneg = float.fromhex("0x1p-5") # 1/32
169+
max_ = float.fromhex("0x1.Fp3") # 15.5 max normal
170+
171+
obj = object.__new__(np.finfo)
172+
obj.dtype = _float8_e3m4_dtype
173+
obj.bits = 8
174+
obj.eps = float8_e3m4(eps)
175+
obj.epsneg = float8_e3m4(epsneg)
176+
obj.machep = -4
177+
obj.negep = -5
178+
obj.max = float8_e3m4(max_)
179+
obj.min = float8_e3m4(-max_)
180+
obj.nexp = 3
181+
obj.nmant = 4
182+
obj.iexp = obj.nexp
183+
obj.maxexp = 4
184+
obj.minexp = -2
185+
obj.precision = 1
186+
obj.resolution = float8_e3m4(resolution)
187+
# pylint: disable=protected-access
188+
obj._machar = _Float8E3m4MachArLike()
189+
if not hasattr(obj, "tiny"):
190+
obj.tiny = float8_e3m4(tiny)
191+
if not hasattr(obj, "smallest_normal"):
192+
obj.smallest_normal = obj._machar.smallest_normal
193+
obj.smallest_subnormal = obj._machar.smallest_subnormal
194+
195+
obj._str_tiny = float_to_str(tiny)
196+
obj._str_smallest_normal = float_to_str(tiny)
197+
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
198+
obj._str_max = float_to_str(max_)
199+
obj._str_epsneg = float_to_str(epsneg)
200+
obj._str_eps = float_to_str(eps)
201+
obj._str_resolution = float_to_str(resolution)
202+
# pylint: enable=protected-access
203+
return obj
204+
149205
@staticmethod
150206
def _float8_e4m3_finfo():
151207
def float_to_str(f):
@@ -425,6 +481,14 @@ def __new__(cls, dtype):
425481
if _bfloat16_dtype not in cls._finfo_cache:
426482
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
427483
return cls._finfo_cache[_bfloat16_dtype]
484+
if (
485+
isinstance(dtype, str)
486+
and dtype == "float8_e3m4"
487+
or dtype == _float8_e3m4_dtype
488+
):
489+
if _float8_e3m4_dtype not in cls._finfo_cache:
490+
cls._finfo_cache[_float8_e3m4_dtype] = cls._float8_e3m4_finfo()
491+
return cls._finfo_cache[_float8_e3m4_dtype]
428492
if (
429493
isinstance(dtype, str)
430494
and dtype == "float8_e4m3"

ml_dtypes/_src/dtypes.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
6060
static constexpr char kNpyDescrByteorder = '=';
6161
};
6262

63+
template <>
64+
struct TypeDescriptor<float8_e3m4> : CustomFloatType<float8_e3m4> {
65+
typedef float8_e3m4 T;
66+
static constexpr bool is_floating = true;
67+
static constexpr bool is_integral = false;
68+
static constexpr const char* kTypeName = "float8_e3m4";
69+
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e3m4";
70+
static constexpr const char* kTpDoc = "float8_e3m4 floating-point values";
71+
// Set e3m4 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
72+
static constexpr char kNpyDescrKind = 'V'; // Void
73+
static constexpr char kNpyDescrType = '3';
74+
static constexpr char kNpyDescrByteorder = '='; // Native byte order
75+
};
76+
6377
template <>
6478
struct TypeDescriptor<float8_e4m3> : CustomFloatType<float8_e4m3> {
6579
typedef float8_e4m3 T;
@@ -283,6 +297,9 @@ bool Initialize() {
283297
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
284298
return false;
285299
}
300+
if (!RegisterFloatDtype<float8_e3m4>(numpy.get())) {
301+
return false;
302+
}
286303
if (!RegisterFloatDtype<float8_e4m3>(numpy.get())) {
287304
return false;
288305
}
@@ -342,6 +359,13 @@ bool Initialize() {
342359
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fnuz, float>();
343360
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fn, float>();
344361
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2, float>();
362+
success &= RegisterTwoWayCustomCast<float8_e3m4, bfloat16, float>();
363+
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3b11fnuz, float>();
364+
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e5m2fnuz, float>();
365+
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3fnuz, float>();
366+
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3fn, float>();
367+
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e5m2, float>();
368+
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3, float>();
345369
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
346370
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
347371
return success;
@@ -372,6 +396,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
372396
return nullptr;
373397
}
374398

399+
if (PyObject_SetAttrString(m.get(), "float8_e3m4",
400+
reinterpret_cast<PyObject*>(
401+
TypeDescriptor<float8_e3m4>::type_ptr)) < 0) {
402+
return nullptr;
403+
}
375404
if (PyObject_SetAttrString(m.get(), "float8_e4m3",
376405
reinterpret_cast<PyObject*>(
377406
TypeDescriptor<float8_e4m3>::type_ptr)) < 0) {

0 commit comments

Comments
 (0)