Skip to content

Commit 4477bf7

Browse files
committed
Add check for NaN rounding
1 parent 3213723 commit 4477bf7

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ public static short floatToBFloat16(float f) {
2323
// denormal - zero exp, non-zero fraction
2424
// infinity - all-1 exp, zero fraction
2525
// NaN - all-1 exp, non-zero fraction
26-
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
27-
// infinities
26+
27+
// note that floatToIntBits doesn't maintain specific NaN values,
28+
// unlike floatToRawIntBits, but instead can return different NaN bit patterns.
29+
// this means that a NaN is unlikely to be turned into infinity by rounding
2830

2931
int bits = Float.floatToIntBits(f);
3032
int bfloat16 = bits >>> 16;

server/src/test/java/org/elasticsearch/index/codec/vectors/BFloat16Tests.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ public void testRoundToEven() {
5252
assertRounding(construct(0b000000000, 0b0000000_10000000_00000000), 0f);
5353

5454
// rounding the standard NaN value should be unchanged
55-
assertThat(Float.floatToIntBits(BFloat16.truncateToBFloat16(Float.NaN)), equalTo(Float.floatToIntBits(Float.NaN)));
55+
assertThat(Float.floatToRawIntBits(BFloat16.truncateToBFloat16(Float.NaN)), equalTo(Float.floatToRawIntBits(Float.NaN)));
56+
57+
// you would expect this to be turned into infinity due to overflow, but instead
58+
// it stays a NaN with a different bit pattern due to using floatToIntBits rather than floatToRawIntBits
59+
// inside floatToBFloat16
60+
assertTrue(Float.isNaN(BFloat16.truncateToBFloat16(construct(0b011111111, 0b0000000_10000000_00000000))));
5661
}
5762

5863
private static float construct(int exp, int mantissa) {
@@ -71,8 +76,11 @@ private static void assertRounding(float value, float expectedRounded) {
7176
float rounded = BFloat16.truncateToBFloat16(value);
7277

7378
// System.out.println(value + " rounds to " + rounded);
74-
assertEquals(value + " rounded to " + rounded + ", not " + expectedRounded,
75-
Float.floatToIntBits(expectedRounded), Float.floatToIntBits(rounded));
79+
assertEquals(
80+
value + " rounded to " + rounded + ", not " + expectedRounded,
81+
Float.floatToIntBits(expectedRounded),
82+
Float.floatToIntBits(rounded)
83+
);
7684

7785
// there should not be a closer bfloat16 value (comparing using FP math) than the expected rounded value
7886
float delta = Math.abs(value - rounded);

0 commit comments

Comments
 (0)