@@ -52,7 +52,12 @@ public void testRoundToEven() {
5252 assertRounding (construct (0b000000000, 0b0000000_10000000_00000000), 0f );
5353
5454 // rounding the standard NaN value should be unchanged
55- assertThat (Float .floatToIntBits (BFloat16 .truncateToBFloat16 (Float .NaN )), equalTo (Float .floatToIntBits (Float .NaN )));
55+ assertThat (Float .floatToRawIntBits (BFloat16 .truncateToBFloat16 (Float .NaN )), equalTo (Float .floatToRawIntBits (Float .NaN )));
56+
57+ // you would expect this to be turned into infinity due to overflow, but instead
58+ // it stays a NaN with a different bit pattern due to using floatToIntBits rather than floatToRawIntBits
59+ // inside floatToBFloat16
60+ assertTrue (Float .isNaN (BFloat16 .truncateToBFloat16 (construct (0b011111111, 0b0000000_10000000_00000000))));
5661 }
5762
5863 private static float construct (int exp , int mantissa ) {
@@ -71,8 +76,11 @@ private static void assertRounding(float value, float expectedRounded) {
7176 float rounded = BFloat16 .truncateToBFloat16 (value );
7277
7378 // System.out.println(value + " rounds to " + rounded);
74- assertEquals (value + " rounded to " + rounded + ", not " + expectedRounded ,
75- Float .floatToIntBits (expectedRounded ), Float .floatToIntBits (rounded ));
79+ assertEquals (
80+ value + " rounded to " + rounded + ", not " + expectedRounded ,
81+ Float .floatToIntBits (expectedRounded ),
82+ Float .floatToIntBits (rounded )
83+ );
7684
7785 // there should not be a closer bfloat16 value (comparing using FP math) than the expected rounded value
7886 float delta = Math .abs (value - rounded );
0 commit comments