7#ifndef CALORECGPU_FPHELPERS_H
8#define CALORECGPU_FPHELPERS_H
10#ifndef CALORECGPU_INCLUDE_CUDA_SUPPORT
12 #define CALORECGPU_INCLUDE_CUDA_SUPPORT 1
23#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
25 #include "cuda_fp16.h"
27 #include "cuda_bf16.h"
31#if __cpp_lib_bitops || __cpp_lib_bit_cast
70 return std::countl_zero(num);
82 T probe = T(1) << (
sizeof(T) * CHAR_BIT - 1);
84 while ((num & probe) == 0 && probe)
94#define CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(ATTRIB, TYPE, BUILTIN) \
96 ATTRIB inline unsigned int count_leading_zeros(const TYPE num) \
100 return sizeof(TYPE) * CHAR_BIT; \
102 return BUILTIN(num); \
107#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
119#elif defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
135#if defined(__clang__)
145#elif defined(_MSC_VER)
195#undef CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER
206 const bool valid = amount <
sizeof(T) * CHAR_BIT;
207 return (
x << (amount * valid)) * valid;
213 const bool valid = amount <
sizeof(T) * CHAR_BIT;
214 return (
x >> (amount * valid)) * valid;
221 return (x1 - x2) * (x1 >= x2);
225 inline static constexpr T
min(
const T x1,
const T x2)
227 return (x1 > x2) * x2 + (x1 <= x2) * x1;
231 inline static constexpr T
max(
const T x1,
const T x2)
233 return (x1 > x2) * x1 + (x1 <= x2) * x2;
237 inline static constexpr T
clamp(
const T
x,
const T low,
const T high)
239 return low * (
x < low) + high * (
x > high) +
x * (
x >= low &&
x <= high);
245 inline static constexpr T
bit_and(
const T x1,
const T x2)
251 inline static constexpr T
bit_or(
const T x1,
const T x2)
261#if __cpp_lib_bit_cast
264 template <
class To,
class From>
265 constexpr inline static To
bitcast(
const From &
x)
267 return std::bit_cast<To, From>(
x);
274 template <
class To,
class From>
278 std::memcpy(&ret, &
x,
sizeof(To));
284#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
286#define CALORECGPU_CUDACAST_HELPER(TYPE_TO, TYPE_FROM, CONVFUNC) \
287 template <> __device__ constexpr inline \
288 TYPE_TO bitcast< TYPE_TO, TYPE_FROM >(const TYPE_FROM &x) \
290 return CONVFUNC (x); \
294 CALORECGPU_CUDACAST_HELPER( int64_t,
double, __double_as_longlong );
295 CALORECGPU_CUDACAST_HELPER( uint64_t,
double, __double_as_longlong );
296 CALORECGPU_CUDACAST_HELPER(
double, int64_t, __longlong_as_double );
297 CALORECGPU_CUDACAST_HELPER(
double, uint64_t, __longlong_as_double );
299 CALORECGPU_CUDACAST_HELPER( int32_t,
float, __float_as_int );
300 CALORECGPU_CUDACAST_HELPER( uint32_t,
float, __float_as_uint );
301 CALORECGPU_CUDACAST_HELPER( int64_t,
float, __float_as_uint );
302 CALORECGPU_CUDACAST_HELPER( uint64_t,
float, __float_as_uint );
303 CALORECGPU_CUDACAST_HELPER(
float, int32_t, __int_as_float );
304 CALORECGPU_CUDACAST_HELPER(
float, uint32_t, __uint_as_float );
305 CALORECGPU_CUDACAST_HELPER(
float, int64_t, __uint_as_float );
306 CALORECGPU_CUDACAST_HELPER(
float, uint64_t, __uint_as_float );
308 CALORECGPU_CUDACAST_HELPER( int16_t, __half, __half_as_short );
309 CALORECGPU_CUDACAST_HELPER( uint16_t, __half, __half_as_ushort );
310 CALORECGPU_CUDACAST_HELPER( int32_t, __half, __half_as_ushort );
311 CALORECGPU_CUDACAST_HELPER( uint32_t, __half, __half_as_ushort );
312 CALORECGPU_CUDACAST_HELPER( int64_t, __half, __half_as_ushort );
313 CALORECGPU_CUDACAST_HELPER( uint64_t, __half, __half_as_ushort );
314 CALORECGPU_CUDACAST_HELPER( __half, int16_t, __short_as_half );
315 CALORECGPU_CUDACAST_HELPER( __half, uint16_t, __ushort_as_half );
316 CALORECGPU_CUDACAST_HELPER( __half, int32_t, __ushort_as_half );
317 CALORECGPU_CUDACAST_HELPER( __half, uint32_t, __ushort_as_half );
318 CALORECGPU_CUDACAST_HELPER( __half, int64_t, __ushort_as_half );
319 CALORECGPU_CUDACAST_HELPER( __half, uint64_t, __ushort_as_half );
350 template <
unsigned int mantiss,
unsigned int exp,
unsigned int tag = 1>
struct IEEE754_like
353 static_assert(mantiss > 0 && exp > 0,
"The exponent and mantissa must contain a positive number of bits!");
357 return mantiss + exp + 1;
373 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
374 "The type must be large enough to hold the bit representation of the floating point." );
375 T ret = (T(1) << mantiss) - 1;
382 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
383 "The type must be large enough to hold the bit representation of the floating point." );
385 T ret = (T(1) << exp) - 1;
386 return ret << mantiss;
392 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
393 "The type must be large enough to hold the bit representation of the floating point." );
394 T ret = T(1) << (exp + mantiss);
407 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
408 "The type must be large enough to hold the bit representation of the floating point." );
409 return (T(1) << (exp - 1)) - 1;
415 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
416 "The type must be large enough to hold the bit representation of the floating point." );
427 constexpr inline static bool is_NaN(
const T pattern)
447 return pattern ^ xor_mask;
457 return pattern ^ xor_mask;
485 constexpr inline static bool round_results(
const bool is_negative,
const bool is_odd,
486 const bool is_nearer_to_up,
const bool is_tied,
499 return is_nearer_to_up || (is_odd && is_tied);
501 return is_nearer_to_up || is_tied;
516 constexpr unsigned int extra_bits = 2;
524 const bool a_denormal = (exp_a != 0);
525 const bool b_denormal = (exp_b != 0);
530 const T mantiss_a = ((
a &
mantissa_mask<T>()) | (first_not_mantissa_bit * a_denormal)) << extra_bits;
531 const T mantiss_b = ((b &
mantissa_mask<T>()) | (first_not_mantissa_bit * b_denormal)) << extra_bits;
534 T mantiss_ret = mantiss_a;
536 mantiss_ret += safe_rshift(mantiss_b, exp_a - exp_b);
538 mantiss_ret |= !!(safe_lshift(mantiss_b, exp_a - exp_b) &
mantissa_mask<T>()) * use_second;
541 constexpr unsigned int desired_number_of_zeros =
sizeof(T) * CHAR_BIT -
mantissa_size_bits() - 1 - extra_bits;
542 const unsigned int shift_amount = clamped_sub(desired_number_of_zeros, leading_zeros);
544 const T last_bit_mask = T(1) << (shift_amount + extra_bits);
545 const T last_discarded_bit_mask = last_bit_mask >> 1;
546 const T round_mask = (last_bit_mask - 1) * !!(last_bit_mask);
547 const bool round_up = (mantiss_ret & round_mask) > last_discarded_bit_mask;
548 const bool tied = last_discarded_bit_mask && ((mantiss_ret & round_mask) == last_discarded_bit_mask);
550 bool round_bit =
round_results<T>(is_negative, (mantiss_ret & last_bit_mask), round_up, tied, rt) && !!last_bit_mask;
552 mantiss_ret = safe_rshift(mantiss_ret, shift_amount + extra_bits);
554 mantiss_ret += round_bit * (shift_amount + extra_bits <=
sizeof(T) * CHAR_BIT);
556 const T exponent_ret = exp_a + shift_amount + (exp_a == 0 && mantiss_ret >
mantissa_mask<T>());
575 constexpr unsigned int extra_bits = 2;
586 const T mantiss_a = ((
a &
mantissa_mask<T>()) | (first_not_mantissa_bit * (exp_a != 0))) << extra_bits;
587 const T mantiss_b = ((b &
mantissa_mask<T>()) | (first_not_mantissa_bit * (exp_b != 0))) << extra_bits;
590 T mantiss_ret = mantiss_a;
592 mantiss_ret -= safe_rshift(mantiss_b, exp_a - exp_b) * use_second;
594 mantiss_ret |= !!(safe_lshift(-mantiss_b, exp_a - exp_b) &
mantissa_mask<T>()) * use_second;
597 constexpr unsigned int desired_number_of_zeros =
sizeof(T) * CHAR_BIT -
mantissa_size_bits() - 1 - extra_bits;
598 const unsigned int shift_amount = clamped_sub(leading_zeros, desired_number_of_zeros);
600 const T last_bit_mask = T(1) << extra_bits;
601 const T last_discarded_bit_mask = last_bit_mask >> 1;
602 const T round_mask = (last_bit_mask - 1) * !!(last_bit_mask);
603 const bool round_up = (mantiss_ret & round_mask) > last_discarded_bit_mask;
604 const bool tied = last_discarded_bit_mask && ((mantiss_ret & round_mask) == last_discarded_bit_mask);
606 bool round_bit =
round_results<T>(is_negative, (mantiss_ret & last_bit_mask), round_up, tied, rt) && !!last_bit_mask;
608 mantiss_ret >>= extra_bits;
610 mantiss_ret += round_bit;
612 mantiss_ret = safe_lshift(mantiss_ret, shift_amount);
614 const T exponent_ret = clamped_sub(exp_a, shift_amount);
616 mantiss_ret = safe_rshift(mantiss_ret, clamped_sub(shift_amount, exp_a));
650 if (sign_a == sign_b)
669 if (sign_a == sign_b)
686 else if (abs_a == abs_b)
706 template <
class FLarge,
class FSmall>
709 static_assert(FSmall::mantissa_size_bits() <= FLarge::mantissa_size_bits() &&
710 FSmall::exponent_size_bits() <= FLarge::exponent_size_bits() );
716 using FDest = FSmall;
717 using FSource = FLarge;
720 const bool sign_bit = pattern & FSource::template sign_mask<T>();
722 const T exponent = (pattern & FSource::template exponent_mask<T>()) >> FSource::mantissa_size_bits();
723 const T mantissa = pattern & FSource::template mantissa_mask<T>();
725 constexpr T delta_exponents = FSource::template exponent_bias<T>() - FDest::template exponent_bias<T>();
727 const bool exponent_full = (exponent > delta_exponents + FDest::template max_exponent_with_bias<T>());
728 const bool delete_mantissa = exponent_full && exponent <= FSource::template max_exponent_with_bias<T>();
731 const bool denormal = exponent <= delta_exponents;
732 const bool zero = exponent + FDest::mantissa_size_bits() <= delta_exponents;
734 const T final_exponent =
min(clamped_sub(exponent, delta_exponents), FDest::template max_exponent_with_bias<T>() + 1);
736 const T extra_mantissa_shift = clamped_sub(delta_exponents + 1, exponent) * denormal;
737 const T total_mantissa_shift = (FSource::mantissa_size_bits() - FDest::mantissa_size_bits()) + extra_mantissa_shift;
739 const T mantissa_keep_mask = safe_lshift(FDest::template mantissa_mask<T>(), total_mantissa_shift) & FSource::template mantissa_mask<T>();
740 const T check_for_rounding_mask = FSource::template mantissa_mask<T>() & (~mantissa_keep_mask);
741 const T last_mantissa_bit = safe_lshift(FDest::template mantissa_mask<T>(), total_mantissa_shift) & FSource::template mantissa_mask<T>();
742 const T first_discarded_bit_mask = last_mantissa_bit >> 1;
745 const T extra_denormal_bit = safe_rshift(T(denormal) << FDest::mantissa_size_bits(), extra_mantissa_shift);
747 const bool round_up = (mantissa & check_for_rounding_mask) > first_discarded_bit_mask;
748 const bool tie_break = ((mantissa & check_for_rounding_mask) == first_discarded_bit_mask);
750 const bool round_bit = FDest::template round_results<T>(sign_bit, mantissa & last_mantissa_bit, round_up, tie_break, rt) && !(exponent_full && !delete_mantissa && !denormal);
753 T final_mantissa = (safe_rshift(mantissa, total_mantissa_shift) | extra_denormal_bit) + round_bit;
755 final_mantissa *= !delete_mantissa;
757 return sign_bit * FDest::template sign_mask<T>() | ((final_exponent << FDest::mantissa_size_bits()) | final_mantissa) * !
zero;
764 using FDest = FLarge;
765 using FSource = FSmall;
768 const bool sign_bit = pattern & FSource::template sign_mask<T>();
769 T exponent = (pattern & FSource::template exponent_mask<T>()) >> FSource::mantissa_size_bits();
770 T mantissa = pattern & FSource::template mantissa_mask<T>();
772 constexpr T delta_exponents = (FDest::template exponent_bias<T>() - FSource::template exponent_bias<T>());
774 if (exponent == 0 && FDest::exponent_size_bits() > FSource::exponent_size_bits())
777 const unsigned int shift_amount = leading_zeros - (
sizeof(T) * CHAR_BIT - FSource::mantissa_size_bits()) + 1;
778 const unsigned int exponent_offset = (mantissa != 0) * (shift_amount - 1);
779 mantissa = (mantissa << shift_amount) & FSource::template mantissa_mask<T>();
781 exponent = delta_exponents - exponent_offset;
783 else if (exponent > FSource::template max_exponent_with_bias<T>())
786 exponent = FDest::template max_exponent_with_bias<T>() + 1;
790 exponent = exponent + delta_exponents;
794 return sign_bit * FDest::template sign_mask<T>() | (exponent << FDest::mantissa_size_bits()) | (mantissa << (FDest::mantissa_size_bits() - FSource::mantissa_size_bits()));
806 template <
class T,
class FDest,
class FSource>
809 static_assert(FDest::mantissa_size_bits() <= FSource::mantissa_size_bits() &&
810 FDest::exponent_size_bits() <= FSource::exponent_size_bits(),
811 "The destination type must not be a larger floating point type than the source one.");
822 template <
class T,
class FDest,
class FSource>
825 static_assert(FDest::mantissa_size_bits() >= FSource::mantissa_size_bits() &&
826 FDest::exponent_size_bits() >= FSource::exponent_size_bits(),
827 "The source type must not be a larger floating point type than the destination one.");
836 template <>
template<
class T>
842 float float_ret = float_a + float_b;
847 template <>
template<
class T>
853 double double_ret = double_a + double_b;
858 template <>
template<
class T>
864 float float_ret = float_a - float_b;
869 template <>
template<
class T>
875 double double_ret = double_a - double_b;
900 template<
class Format>
918#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
923 template <>
template<
class T>
929 __half conv_ret = __hadd(conv_a, conv_b);
934 template <>
template<
class T>
940 __half conv_ret = __hsub(conv_a, conv_b);
945 template <>
template<
class T>
951 __nv_bfloat16 conv_ret = __hadd(conv_a, conv_b);
956 template <>
template<
class T>
962 __nv_bfloat16 conv_ret = __hsub(conv_a, conv_b);
974 const float fl = __half2float(hf);
985 ret = __float2half_ru(pre_conv);
988 ret = __float2half_rd(pre_conv);
991 ret = __float2half_rz(pre_conv);
994 ret = __float2half_rn(pre_conv);
999 ret = __float2half(pre_conv);
1014 const float fl = __bfloat162float(hf);
1025 ret = __float2bfloat16_ru(pre_conv);
1028 ret = __float2bfloat16_rd(pre_conv);
1031 ret = __float2bfloat16_rz(pre_conv);
1034 ret = __float2bfloat16_rn(pre_conv);
// Fallback branch (presumably the default rounding mode): plain float -> bfloat16 intrinsic.
ret = __float2bfloat16(pre_conv);
1052 const T first_step = ConversionHelper<StandardFloat, CUDAHalfFloat>::template
up_convert<T>(pattern, rt);
1053 return ConversionHelper<StandardDouble, StandardFloat>::template
up_convert<T>(first_step, rt);
// Narrow in two steps: double -> float, then float -> CUDA half.
// The first step must start from the input `pattern`; reading `first_step`
// inside its own initializer would use an indeterminate value.
// NOTE(review): rounding twice (double -> float -> half) can differ from a single
// direct rounding for round-to-nearest when the intermediate lands exactly on a tie.
const T first_step = ConversionHelper<StandardDouble, StandardFloat>::template down_convert<T>(pattern, rt);
return ConversionHelper<StandardFloat, CUDAHalfFloat>::template down_convert<T>(first_step, rt);
// Widen in two steps: CUDA bfloat16 -> float, then float -> double.
// The first step must start from the input `pattern`, not from `first_step` itself.
const T first_step = ConversionHelper<StandardFloat, CUDABFloat16>::template up_convert<T>(pattern, rt);
return ConversionHelper<StandardDouble, StandardFloat>::template up_convert<T>(first_step, rt);
// Narrow in two steps: double -> float, then float -> CUDA bfloat16.
// The first step must start from the input `pattern`; reading `first_step`
// inside its own initializer would use an indeterminate value.
// NOTE(review): rounding twice (double -> float -> bfloat16) can differ from a single
// direct rounding for round-to-nearest when the intermediate lands exactly on a tie.
const T first_step = ConversionHelper<StandardDouble, StandardFloat>::template down_convert<T>(pattern, rt);
return ConversionHelper<StandardFloat, CUDABFloat16>::template down_convert<T>(first_step, rt);
#define CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(ATTRIB, TYPE, BUILTIN)
void zero(TH2 *h)
zero the contents of a 2d histogram
static constexpr To bitcast(const From &x)
static constexpr unsigned int count_leading_zeros(const T num)
static constexpr T bit_or(const T x1, const T x2)
static constexpr T safe_lshift(const T x, const T amount)
static constexpr T bit_and(const T x1, const T x2)
static constexpr T clamp(const T x, const T low, const T high)
static constexpr T clamped_sub(const T x1, const T x2)
static constexpr T safe_rshift(const T x, const T amount)
IEEE754_like< 52, 11 > StandardDouble
IEEE754_like< 7, 8 > CUDABFloat16
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Converts pattern from the larger floating point format FSource to FDest.
IEEE754_like< 23, 8 > StandardFloat
RoundingModes
Specifies the rounding mode to use for the operations.
IEEE754_like< 10, 5 > CUDAHalfFloat
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Converts pattern from the smaller floating point format FSource to FDest.
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Specifies a floating point format like those described in IEEE-754, with an adjustable number of bits...
static constexpr unsigned int mantissa_size_bits()
static constexpr T add_patterns(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
The absolute value of a must be greater than or equal to that of b.
static constexpr bool is_NaN(const T pattern)
static constexpr T exponent_bias()
static constexpr T positive_infinity()
static constexpr T exponent_mask()
static constexpr T add(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
static constexpr T subtract(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
static constexpr T negative_infinity()
static constexpr T negative_zero()
static constexpr T from_total_ordering(const T pattern)
static constexpr T full_mask()
static constexpr T to_total_ordering(const T pattern)
static constexpr unsigned int exponent_size_bits()
static constexpr T max_exponent_with_bias()
static constexpr T mantissa_mask()
static constexpr bool round_results(const bool is_negative, const bool is_odd, const bool is_nearer_to_up, const bool is_tied, RoundingModes rt)
static constexpr T positive_zero()
static constexpr T absolute_value(const T pattern)
static constexpr T sign_mask()
static constexpr bool is_infinite(const T pattern)
static constexpr T subtract_patterns(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
The absolute value of a must be greater than or equal to that of b.
static constexpr unsigned int total_size_bits()