@@ -503,15 +503,18 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
503503 """
504504 if not fn :
505505 raise NotImplementedError ("fn=False is not implemented." )
506- b = int .from_bytes (struct .pack ("<f" , numpy .float32 (x )), "little" )
506+ if not isinstance (x , numpy .float32 ):
507+ x = numpy .float32 (x )
508+ b = int .from_bytes (struct .pack ("<f" , x ), "little" )
507509 ret = (b & 0x80000000 ) >> 24 # sign
508510 if uz :
509- if (b & 0x7FC00000 ) == 0x7FC00000 :
510- return 0x80
511- if numpy .isinf (x ):
511+ if (b & 0x7FFFFFFF ) == 0x7F800000 :
512+ # infinity
512513 if saturate :
513514 return ret | 127
514515 return 0x80
516+ if (b & 0x7F800000 ) == 0x7F800000 :
517+ return 0x80
515518 e = (b & 0x7F800000 ) >> 23 # exponent
516519 m = b & 0x007FFFFF # mantissa
517520
@@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
558561 ret = 0
559562 return int (ret )
560563 else :
561- if (b & 0x7FC00000 ) == 0x7FC00000 :
562- return 0x7F | ret
563- if numpy .isinf (x ):
564+ if (b & 0x7FFFFFFF ) == 0x7F800000 :
565+ # infinity
564566 if saturate :
565567 return ret | 126
566568 return 0x7F | ret
569+ if (b & 0x7F800000 ) == 0x7F800000 :
570+ # non
571+ return 0x7F | ret
567572 e = (b & 0x7F800000 ) >> 23 # exponent
568573 m = b & 0x007FFFFF # mantissa
569574
@@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
624629 ret = (b & 0x80000000 ) >> 24 # sign
625630
626631 if fn and uz :
627- if (b & 0x7FC00000 ) == 0x7FC00000 :
628- return 0x80
629632 if (b & 0x7FFFFFFF ) == 0x7F800000 :
630633 # inf
631634 if saturate :
632635 return ret | 0x7F
633636 return 0x80
637+ if (b & 0x7F800000 ) == 0x7F800000 :
638+ # nan
639+ return 0x80
634640 e = (b & 0x7F800000 ) >> 23 # exponent
635641 m = b & 0x007FFFFF # mantissa
636642
@@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
675681 ret = 0
676682 return int (ret )
677683 elif not fn and not uz :
678- if (b & 0x7FC00000 ) == 0x7FC00000 :
679- return 0x7F | ret
680- if numpy .isinf (x ):
684+ if (b & 0x7FFFFFFF ) == 0x7F800000 :
685+ # inf
681686 if saturate :
682687 return 0x7B | ret
683688 return 0x7C | ret
689+ if (b & 0x7F800000 ) == 0x7F800000 :
690+ # nan
691+ return 0x7F | ret
684692 e = (b & 0x7F800000 ) >> 23 # exponent
685693 m = b & 0x007FFFFF # mantissa
686694
0 commit comments