1717from typing import Dict
1818
1919from ml_dtypes ._ml_dtypes_ext import bfloat16
20+ from ml_dtypes ._ml_dtypes_ext import float8_e3m4
2021from ml_dtypes ._ml_dtypes_ext import float8_e4m3
2122from ml_dtypes ._ml_dtypes_ext import float8_e4m3b11fnuz
2223from ml_dtypes ._ml_dtypes_ext import float8_e4m3fn
2627import numpy as np
2728
2829_bfloat16_dtype = np .dtype (bfloat16 )
30+ _float8_e3m4_dtype = np .dtype (float8_e3m4 )
2931_float8_e4m3_dtype = np .dtype (float8_e4m3 )
3032_float8_e4m3b11fnuz_dtype = np .dtype (float8_e4m3b11fnuz )
3133_float8_e4m3fn_dtype = np .dtype (float8_e4m3fn )
@@ -43,12 +45,21 @@ def __init__(self):
4345 self .smallest_subnormal = bfloat16 (smallest_subnormal )
4446
4547
48+ class _Float8E3m4MachArLike :
49+
50+ def __init__ (self ):
51+ smallest_normal = float .fromhex ("0x1p-2" )
52+ self .smallest_normal = float8_e3m4 (smallest_normal )
53+ smallest_subnormal = float .fromhex ("0x0.1p-2" )
54+ self .smallest_subnormal = float8_e3m4 (smallest_subnormal )
55+
56+
4657class _Float8E4m3MachArLike :
4758
4859 def __init__ (self ):
4960 smallest_normal = float .fromhex ("0x1p-6" )
5061 self .smallest_normal = float8_e4m3 (smallest_normal )
51- smallest_subnormal = float .fromhex ("0x1p-9 " )
62+ smallest_subnormal = float .fromhex ("0x0.2p-6 " )
5263 self .smallest_subnormal = float8_e4m3 (smallest_subnormal )
5364
5465
@@ -146,6 +157,51 @@ def float_to_str(f):
146157 # pylint: enable=protected-access
147158 return obj
148159
160+ @staticmethod
161+ def _float8_e3m4_finfo ():
162+ def float_to_str (f ):
163+ return "%6.2e" % float (f )
164+
165+ tiny = float .fromhex ("0x1p-2" ) # 1/4 min normal
166+ resolution = 0.1
167+ eps = float .fromhex ("0x1p-4" ) # 1/16
168+ epsneg = float .fromhex ("0x1p-5" ) # 1/32
169+ max_ = float .fromhex ("0x1.Fp3" ) # 15.5 max normal
170+
171+ obj = object .__new__ (np .finfo )
172+ obj .dtype = _float8_e3m4_dtype
173+ obj .bits = 8
174+ obj .eps = float8_e3m4 (eps )
175+ obj .epsneg = float8_e3m4 (epsneg )
176+ obj .machep = - 4
177+ obj .negep = - 5
178+ obj .max = float8_e3m4 (max_ )
179+ obj .min = float8_e3m4 (- max_ )
180+ obj .nexp = 3
181+ obj .nmant = 4
182+ obj .iexp = obj .nexp
183+ obj .maxexp = 4
184+ obj .minexp = - 2
185+ obj .precision = 1
186+ obj .resolution = float8_e3m4 (resolution )
187+ # pylint: disable=protected-access
188+ obj ._machar = _Float8E3m4MachArLike ()
189+ if not hasattr (obj , "tiny" ):
190+ obj .tiny = float8_e3m4 (tiny )
191+ if not hasattr (obj , "smallest_normal" ):
192+ obj .smallest_normal = obj ._machar .smallest_normal
193+ obj .smallest_subnormal = obj ._machar .smallest_subnormal
194+
195+ obj ._str_tiny = float_to_str (tiny )
196+ obj ._str_smallest_normal = float_to_str (tiny )
197+ obj ._str_smallest_subnormal = float_to_str (obj .smallest_subnormal )
198+ obj ._str_max = float_to_str (max_ )
199+ obj ._str_epsneg = float_to_str (epsneg )
200+ obj ._str_eps = float_to_str (eps )
201+ obj ._str_resolution = float_to_str (resolution )
202+ # pylint: enable=protected-access
203+ return obj
204+
149205 @staticmethod
150206 def _float8_e4m3_finfo ():
151207 def float_to_str (f ):
@@ -425,6 +481,14 @@ def __new__(cls, dtype):
425481 if _bfloat16_dtype not in cls ._finfo_cache :
426482 cls ._finfo_cache [_bfloat16_dtype ] = cls ._bfloat16_finfo ()
427483 return cls ._finfo_cache [_bfloat16_dtype ]
484+ if (
485+ isinstance (dtype , str )
486+ and dtype == "float8_e3m4"
487+ or dtype == _float8_e3m4_dtype
488+ ):
489+ if _float8_e3m4_dtype not in cls ._finfo_cache :
490+ cls ._finfo_cache [_float8_e3m4_dtype ] = cls ._float8_e3m4_finfo ()
491+ return cls ._finfo_cache [_float8_e3m4_dtype ]
428492 if (
429493 isinstance (dtype , str )
430494 and dtype == "float8_e4m3"
0 commit comments