ATLAS Offline Software
Loading...
Searching...
No Matches
FPCompressionUtils.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3
4 Header-only utilities for reduced-precision float compression.
5
6 Bit-truncation of the float32 representation, parameterised by exponent-
7 and mantissa-bit budgets (E, M):
8
9 - truncateToFloat(v, E, M) : RNE-truncated float32, low (31-E-M) bits
10 zero, exponent range clamped to what E bits can address.
11
12 Special case: (E=8, M=7) reproduces bfloat16 exactly (no exponent
13 clamping since E==8 covers float32's full range).
14
15 All functions use round-to-nearest-even.
16*/
17
18#ifndef FLAVORTAGINFERENCE_FPCOMPRESSIONUTILS_H
19#define FLAVORTAGINFERENCE_FPCOMPRESSIONUTILS_H
20
21#include <bit>
22#include <cmath>
23#include <cstdint>
24
25namespace FlavorTagInference {
27
// RNE-truncated float32 with an (exp_bits, mantissa_bits) precision budget.
//
// The low k = 31 - exp_bits - mantissa_bits bits of the float32 pattern are
// rounded away with round-to-nearest-even and left as zero. If exp_bits < 8,
// the rounded result is additionally confined to the exponent range
// [-(max_exp - 1), max_exp], where max_exp = 2^(exp_bits - 1) - 1:
//   - magnitudes above the range saturate to the largest in-range pattern
//     whose low k bits are zero (sign preserved);
//   - magnitudes below the range are flushed to zero (sign preserved).
// With exp_bits >= 8 no clamping is applied, so rounding may carry a value
// near the float32 maximum up to +/-Inf.
//
// NaN and +/-Inf inputs are returned unchanged.
//
// Preconditions: exp_bits >= 1 (the range computation shifts by
// exp_bits - 1) and mantissa_bits >= 0.
inline float truncateToFloat(float val, int exp_bits, int mantissa_bits) {
  const uint32_t in_bits = std::bit_cast<uint32_t>(val);
  const uint32_t sign = in_bits & 0x80000000u;
  uint32_t mag = in_bits & 0x7FFFFFFFu;

  // NaN/Inf pass through untouched. Running the rounding add on a NaN can
  // carry out of the mantissa and turn it into Inf or a signed zero.
  if ((mag >> 23) == 0xFFu) {
    return val;
  }

  // RNE-round the magnitude: zero the low k bits with a round-to-nearest-even
  // bias. Working on the magnitude keeps any carry away from the sign bit.
  const int k = 31 - exp_bits - mantissa_bits;
  uint32_t keep_mask = 0xFFFFFFFFu;
  if (k > 0 && k < 32) {
    keep_mask = ~((1u << k) - 1u);
    const uint32_t round_bias = (1u << (k - 1)) - 1u + ((mag >> k) & 1u);
    mag = (mag + round_bias) & keep_mask;
  }

  // Exponent clamping (only meaningful if E < 8; float32 uses E=8). Done
  // after rounding so that a round-up carrying into the exponent cannot
  // leave the representable range.
  if (exp_bits < 8) {
    const int max_exp = (1 << (exp_bits - 1)) - 1;  // e.g. E=5 -> 15
    const int min_exp = -(max_exp - 1);
    // Not masked to 8 bits: a carry past the exponent field must read as
    // "too large", not wrap around to a small exponent.
    const int actual_exp = static_cast<int>(mag >> 23) - 127;
    if (actual_exp > max_exp) {
      const uint32_t sat_exp = static_cast<uint32_t>(max_exp + 127) << 23;
      const uint32_t sat_man = (1u << 23) - 1u;
      // Mask the saturation pattern so its low k bits are zero as well.
      mag = (sat_exp | sat_man) & keep_mask;
    } else if (actual_exp < min_exp) {
      mag = 0u;
    }
  }

  return std::bit_cast<float>(sign | mag);
}
62
63} // namespace FPCompressionUtils
64} // namespace FlavorTagInference
65
66#endif // FLAVORTAGINFERENCE_FPCOMPRESSIONUTILS_H
int sign(int a)
float truncateToFloat(float val, int exp_bits, int mantissa_bits)
This file contains "getter" functions used for accessing tagger inputs from the EDM.