7 #ifndef CALORECGPU_FPHELPERS_H
8 #define CALORECGPU_FPHELPERS_H
10 #ifndef CALORECGPU_INCLUDE_CUDA_SUPPORT
12 #define CALORECGPU_INCLUDE_CUDA_SUPPORT 1
23 #if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
25 #include "cuda_fp16.h"
27 #include "cuda_bf16.h"
31 #if __cpp_lib_bitops || __cpp_lib_bit_cast
// Portability layer for counting the leading zero bits of an integer.
// NOTE(review): this chunk is heavily elided -- the fragments below belong to
// a std::countl_zero-based overload, a portable bit-probe fallback, and a
// macro wrapping compiler builtins; template headers and braces were lost.
62 namespace LeadingZerosPortability
// C++20 path: delegate to std::countl_zero (guarded by the feature-test
// macro on the preceding #if line).
68 inline static constexpr
unsigned int count_leading_zeros(
const T
num)
70 return std::countl_zero(
num);
// Portable fallback: probe from the most significant bit downwards until a
// set bit is found (probe eventually becomes zero, terminating for num == 0).
76 inline static constexpr
unsigned int count_leading_zeros(
const T
num)
82 T probe = T(1) << (
sizeof(T) * CHAR_BIT - 1);
84 while ((
num & probe) == 0 && probe)
// Helper macro: defines count_leading_zeros for TYPE via a compiler BUILTIN,
// returning the full bit width for a zero input (presumably because builtins
// like __builtin_clz are undefined for 0 -- the guarding `if` was elided).
94 #define CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(ATTRIB, TYPE, BUILTIN) \
96 ATTRIB inline unsigned int count_leading_zeros(const TYPE num) \
100 return sizeof(TYPE) * CHAR_BIT; \
102 return BUILTIN(num); \
107 #if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
119 #elif defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
135 #if defined(__clang__)
145 #elif defined(_MSC_VER)
155 namespace LeadingZerosPortability
175 inline unsigned long long int count_leading_zeros<unsigned long long int>(
unsigned long long int T num)
177 const auto res_1 = count_leading_zeros<unsigned int>(
num >> 32);
178 return res_1 + (res_1 == 32) * count_leading_zeros<unsigned int>(
num);
182 inline long long int count_leading_zeros<long long int>(
long long int T
num)
184 return count_leading_zeros<unsigned long long int>(
num);
195 #undef CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER
199 namespace OperatorsHelper
/// @brief Left-shift that is well defined for any shift amount: shifting by
///        sizeof(T)*CHAR_BIT or more yields 0 instead of undefined behavior.
/// NOTE(review): the shipped fragment only computed `valid`; the guarded
/// shift-and-mask return is reconstructed here to match the call sites,
/// which rely on out-of-range shifts evaluating to zero.
template <class T>
inline static constexpr T safe_lshift(const T x, const T amount)
{
  const bool valid = amount < sizeof(T) * CHAR_BIT;
  // When invalid, shift by 0 (no UB) and multiply the result away.
  return (x << (amount * valid)) * valid;
}

/// @brief Right-shift counterpart of safe_lshift: out-of-range amounts
///        yield 0 with no undefined behavior.
template <class T>
inline static constexpr T safe_rshift(const T x, const T amount)
{
  const bool valid = amount < sizeof(T) * CHAR_BIT;
  return (x >> (amount * valid)) * valid;
}
/// @brief Saturating subtraction: x1 - x2 when x1 > x2, otherwise 0.
/// NOTE(review): the bodies of these three helpers were elided in this
/// chunk; reconstructed to match the visible call sites (shift-amount and
/// exponent computations), which require the result clamped at zero for
/// unsigned T.
template <class T>
inline static constexpr T clamped_sub(const T x1, const T x2)
{
  return (x1 > x2) ? x1 - x2 : T(0);
}

/// @brief Minimum of two values.
template <class T>
inline static constexpr T min(const T x1, const T x2)
{
  return (x2 < x1) ? x2 : x1;
}

/// @brief Maximum of two values.
template <class T>
inline static constexpr T max(const T x1, const T x2)
{
  return (x1 < x2) ? x2 : x1;
}
/// @brief Branchless clamp of x to [low, high]: for consistent bounds
///        (low <= high) exactly one of the three weighted terms is non-zero,
///        so the result is low, high, or x accordingly.
template <class T>
inline static constexpr T clamp(const T x, const T low, const T high)
{
  const bool below_range = x < low;
  const bool above_range = x > high;
  const bool within_range = x >= low && x <= high;
  return low * below_range + high * above_range + x * within_range;
}
/// @brief Bitwise AND as a named function (usable as a binary operation in
///        generic code). NOTE(review): the bodies were elided in this chunk
///        and are reconstructed from the functions' names and signatures.
template <class T>
inline static constexpr T bit_and(const T x1, const T x2)
{
  return x1 & x2;
}

/// @brief Bitwise OR as a named function.
template <class T>
inline static constexpr T bit_or(const T x1, const T x2)
{
  return x1 | x2;
}
258 namespace BitCastHelper
261 #if __cpp_lib_bit_cast
264 template <
class To,
class From>
265 constexpr
inline static To bitcast(
const From &
x)
267 return std::bit_cast<To, From>(
x);
/// @brief Fallback bit-level reinterpretation for standard libraries without
///        std::bit_cast: copy the bytes of x into a To object via memcpy.
/// NOTE(review): the original fragment lacked the destination object and the
/// return statement; reconstructed here. Reads sizeof(To) bytes from &x, so
/// it assumes sizeof(From) >= sizeof(To) -- TODO confirm all instantiations
/// on this (non-CUDA) path satisfy that.
template <
  class To,
  class From>
constexpr
inline static To bitcast(const From & x)
{
  To ret{};
  std::memcpy(&ret, &x, sizeof(To));
  return ret;
}
284 #if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
// Device-side bitcast specializations: each pattern <-> floating-point
// conversion is mapped onto the corresponding CUDA reinterpretation
// intrinsic via this macro.
286 #define CALORECGPU_CUDACAST_HELPER(TYPE_TO, TYPE_FROM, CONVFUNC) \
287 template <> __device__ constexpr inline \
288 TYPE_TO bitcast< TYPE_TO, TYPE_FROM >(const TYPE_FROM &x) \
290 return CONVFUNC (x); \
294 CALORECGPU_CUDACAST_HELPER( int64_t,
double, __double_as_longlong );
// 64-bit integer patterns <-> double.
295 CALORECGPU_CUDACAST_HELPER(
uint64_t,
double, __double_as_longlong );
296 CALORECGPU_CUDACAST_HELPER(
double, int64_t, __longlong_as_double );
297 CALORECGPU_CUDACAST_HELPER(
double,
uint64_t, __longlong_as_double );
// 32-bit integer patterns <-> float.
299 CALORECGPU_CUDACAST_HELPER( int32_t,
float, __float_as_int );
300 CALORECGPU_CUDACAST_HELPER(
uint32_t,
float, __float_as_uint );
301 CALORECGPU_CUDACAST_HELPER( int64_t,
float, __float_as_uint );
// NOTE(review): the 64-bit <-> float casts above/below go through 32-bit
// intrinsics, so only the low 32 bits of the pattern are meaningful --
// presumably intentional widening/narrowing of the bit pattern; confirm.
302 CALORECGPU_CUDACAST_HELPER(
uint64_t,
float, __float_as_uint );
303 CALORECGPU_CUDACAST_HELPER(
float, int32_t, __int_as_float );
304 CALORECGPU_CUDACAST_HELPER(
float,
uint32_t, __uint_as_float );
305 CALORECGPU_CUDACAST_HELPER(
float, int64_t, __uint_as_float );
306 CALORECGPU_CUDACAST_HELPER(
float,
uint64_t, __uint_as_float );
// Integer patterns <-> __half; wider integer types route through the
// 16-bit reinterpretation intrinsics (only the low 16 bits carry data).
308 CALORECGPU_CUDACAST_HELPER(
int16_t, __half, __half_as_short );
309 CALORECGPU_CUDACAST_HELPER(
uint16_t, __half, __half_as_ushort );
310 CALORECGPU_CUDACAST_HELPER( int32_t, __half, __half_as_ushort );
311 CALORECGPU_CUDACAST_HELPER(
uint32_t, __half, __half_as_ushort );
312 CALORECGPU_CUDACAST_HELPER( int64_t, __half, __half_as_ushort );
313 CALORECGPU_CUDACAST_HELPER(
uint64_t, __half, __half_as_ushort );
314 CALORECGPU_CUDACAST_HELPER( __half,
int16_t, __short_as_half );
315 CALORECGPU_CUDACAST_HELPER( __half,
uint16_t, __ushort_as_half );
316 CALORECGPU_CUDACAST_HELPER( __half, int32_t, __ushort_as_half );
317 CALORECGPU_CUDACAST_HELPER( __half,
uint32_t, __ushort_as_half );
318 CALORECGPU_CUDACAST_HELPER( __half, int64_t, __ushort_as_half );
319 CALORECGPU_CUDACAST_HELPER( __half,
uint64_t, __ushort_as_half );
// Describes an IEEE-754-like binary floating point format with `mantiss`
// mantissa bits, `exp` exponent bits and one sign bit; `tag` presumably
// disambiguates formats of identical widths. NOTE(review): this chunk is
// heavily elided -- the member function signatures between the fragments
// below were lost, so each body is annotated with its presumed owner.
350 template <
unsigned int mantiss,
unsigned int exp,
unsigned int tag = 1>
struct IEEE754_like
353 static_assert(mantiss > 0 &&
exp > 0,
"The exponent and mantissa must contain a positive number of bits!");
// Total storage width: mantissa + exponent + sign bit.
357 return mantiss +
exp + 1;
// Mask of the mantissa bits (low `mantiss` bits of pattern type T).
373 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss +
exp + 1),
374 "The type must be large enough to hold the bit representation of the floating point." );
375 T ret = (T(1) << mantiss) - 1;
// Mask of the exponent bits, positioned just above the mantissa.
382 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss +
exp + 1),
383 "The type must be large enough to hold the bit representation of the floating point." );
385 T ret = (T(1) <<
exp) - 1;
386 return ret << mantiss;
// Mask of the sign bit (one position past the exponent field).
392 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss +
exp + 1),
393 "The type must be large enough to hold the bit representation of the floating point." );
394 T ret = T(1) << (
exp + mantiss);
// Mask covering every significant bit of the representation.
401 return mantissa_mask<T>() | exponent_mask<T>() | sign_mask<T>();
// Exponent bias: 2^(exp - 1) - 1, as in standard IEEE-754 formats.
407 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss +
exp + 1),
408 "The type must be large enough to hold the bit representation of the floating point." );
409 return (T(1) << (
exp - 1)) - 1;
// Largest biased exponent of a finite value (2 * bias).
415 static_assert(
sizeof(T) * CHAR_BIT >= (mantiss +
exp + 1),
416 "The type must be large enough to hold the bit representation of the floating point." );
417 return exponent_bias<T>() * 2;
// Infinity test: exponent all-ones and zero mantissa (sign ignored).
423 return (
pattern & (~sign_mask<T>())) == exponent_mask<T>();
// NaN test: exponent all-ones with non-zero mantissa (sign ignored).
429 return (
pattern & (~sign_mask<T>())) > exponent_mask<T>();
// Absolute value of a pattern: clear the sign bit.
437 return pattern & (~sign_mask<T>());
// Total-ordering transform: negative patterns are fully inverted, positive
// ones just get the sign bit set -- TODO confirm direction in full source.
446 const T xor_mask = (!!(
pattern & sign_mask<T>()) * full_mask<T>()) | sign_mask<T>();
// Inverse of the total-ordering transform above (note the single `!`).
456 const T xor_mask = (!(
pattern & sign_mask<T>()) * full_mask<T>()) | sign_mask<T>();
// Fragments of special-value pattern accessors -- presumably negative zero,
// positive infinity and negative infinity respectively; owners elided.
469 return sign_mask<T>();
475 return exponent_mask<T>();
481 return sign_mask<T>() | exponent_mask<T>();
// Decide whether rounding adds one ulp given the sign, the parity of the
// last kept bit, whether the discarded part exceeds the halfway point, and
// whether it ties; the elided final parameter selects the rounding mode.
485 constexpr
inline static bool round_results(
const bool is_negative,
const bool is_odd,
486 const bool is_nearer_to_up,
const bool is_tied,
// Round-to-nearest-even branch: past halfway, or a tie with an odd bit.
499 return is_nearer_to_up || (is_odd && is_tied);
// A directed/away rounding branch: any non-zero discarded part rounds.
501 return is_nearer_to_up || is_tied;
// Core magnitude addition of two same-format bit patterns (presumably with
// |a| >= |b|, given the operand-swapping dispatcher further below). Works
// on mantissas widened by `extra_bits` guard bits so rounding information
// survives the alignment shifts. NOTE(review): heavily elided fragment;
// exp_a, exp_b, use_second, first_not_mantissa_bit and rt are declared on
// lines not visible in this chunk.
514 using namespace OperatorsHelper;
516 constexpr
unsigned int extra_bits = 2;
// NOTE(review): the names say "denormal" but the condition is exp != 0,
// i.e. true for *normal* numbers; they are used below as the implicit
// leading-bit flag. Confirm the naming against the full source.
524 const bool a_denormal = (exp_a != 0);
525 const bool b_denormal = (exp_b != 0);
528 const bool is_negative =
a & sign_mask<T>();
// Mantissas with the implicit leading 1 restored for normal numbers,
// shifted up by the guard bits.
530 const T mantiss_a = ((
a & mantissa_mask<T>()) | (first_not_mantissa_bit * a_denormal)) << extra_bits;
531 const T mantiss_b = ((
b & mantissa_mask<T>()) | (first_not_mantissa_bit * b_denormal)) << extra_bits;
534 T mantiss_ret = mantiss_a;
// Align b's mantissa to a's exponent and accumulate; then OR in a sticky
// bit for any precision lost in the alignment shift.
536 mantiss_ret += safe_rshift(mantiss_b, exp_a - exp_b);
538 mantiss_ret |= !!(safe_lshift(mantiss_b, exp_a - exp_b) & mantissa_mask<T>()) * use_second;
// Renormalize: compute how far the sum must shift right so exactly one
// leading bit sits above mantissa + guard bits.
540 const unsigned int leading_zeros = LeadingZerosPortability::count_leading_zeros<T>(mantiss_ret);
541 constexpr
unsigned int desired_number_of_zeros =
sizeof(T) * CHAR_BIT -
mantissa_size_bits() - 1 - extra_bits;
542 const unsigned int shift_amount = clamped_sub(desired_number_of_zeros, leading_zeros);
// Round using the bits that the final shift will discard.
544 const T last_bit_mask = T(1) << (shift_amount + extra_bits);
545 const T last_discarded_bit_mask = last_bit_mask >> 1;
546 const T round_mask = (last_bit_mask - 1) * !!(last_bit_mask);
547 const bool round_up = (mantiss_ret & round_mask) > last_discarded_bit_mask;
548 const bool tied = last_discarded_bit_mask && ((mantiss_ret & round_mask) == last_discarded_bit_mask);
550 bool round_bit = round_results<T>(is_negative, (mantiss_ret & last_bit_mask), round_up, tied, rt) && !!last_bit_mask;
552 mantiss_ret = safe_rshift(mantiss_ret, shift_amount + extra_bits);
554 mantiss_ret += round_bit * (shift_amount + extra_bits <=
sizeof(T) * CHAR_BIT);
// Exponent adjusted for renormalization; a denormal input whose mantissa
// overflowed becomes normal (the extra +1 term).
556 const T exponent_ret = exp_a + shift_amount + (exp_a == 0 && mantiss_ret > mantissa_mask<T>());
558 mantiss_ret &= mantissa_mask<T>();
// Overflow to infinity: clear the mantissa once the exponent saturates.
560 mantiss_ret &= ~( ( exponent_ret > max_exponent_with_bias<T>() ) * mantissa_mask<T>() );
// Reassemble sign | exponent | mantissa.
564 return (is_negative * sign_mask<T>()) | (exponent_ret <<
mantissa_size_bits()) | mantiss_ret;
// Core magnitude subtraction of two same-format patterns (presumably with
// |a| >= |b|, per the sign-handling dispatcher below), using the same
// guard-bit scheme as the addition above. NOTE(review): elided fragment;
// exp_a, exp_b, use_second, first_not_mantissa_bit and rt are declared on
// lines not visible here.
573 using namespace OperatorsHelper;
575 constexpr
unsigned int extra_bits = 2;
584 const bool is_negative =
a & sign_mask<T>();
// Mantissas with the implicit leading bit restored for normal inputs
// (exp != 0), widened by the guard bits.
586 const T mantiss_a = ((
a & mantissa_mask<T>()) | (first_not_mantissa_bit * (exp_a != 0))) << extra_bits;
587 const T mantiss_b = ((
b & mantissa_mask<T>()) | (first_not_mantissa_bit * (exp_b != 0))) << extra_bits;
590 T mantiss_ret = mantiss_a;
// Subtract b's aligned mantissa; the OR folds a sticky bit for the bits
// lost in alignment (note the negated operand in the sticky computation).
592 mantiss_ret -= safe_rshift(mantiss_b, exp_a - exp_b) * use_second;
594 mantiss_ret |= !!(safe_lshift(-mantiss_b, exp_a - exp_b) & mantissa_mask<T>()) * use_second;
// Renormalize after possible cancellation: here the result may need a
// *left* shift (leading_zeros can exceed the desired count).
596 const unsigned int leading_zeros = LeadingZerosPortability::count_leading_zeros<T>(mantiss_ret);
597 constexpr
unsigned int desired_number_of_zeros =
sizeof(T) * CHAR_BIT -
mantissa_size_bits() - 1 - extra_bits;
598 const unsigned int shift_amount = clamped_sub(leading_zeros, desired_number_of_zeros);
// Rounding uses only the guard bits (the subtraction cannot need a
// right-shift renormalization beyond them).
600 const T last_bit_mask = T(1) << extra_bits;
601 const T last_discarded_bit_mask = last_bit_mask >> 1;
602 const T round_mask = (last_bit_mask - 1) * !!(last_bit_mask);
603 const bool round_up = (mantiss_ret & round_mask) > last_discarded_bit_mask;
604 const bool tied = last_discarded_bit_mask && ((mantiss_ret & round_mask) == last_discarded_bit_mask);
606 bool round_bit = round_results<T>(is_negative, (mantiss_ret & last_bit_mask), round_up, tied, rt) && !!last_bit_mask;
608 mantiss_ret >>= extra_bits;
610 mantiss_ret += round_bit;
// Shift the mantissa up to renormalize; if the exponent would underflow,
// shift back down into the denormal range instead.
612 mantiss_ret = safe_lshift(mantiss_ret, shift_amount);
614 const T exponent_ret = clamped_sub(exp_a, shift_amount);
616 mantiss_ret = safe_rshift(mantiss_ret, clamped_sub(shift_amount, exp_a));
618 mantiss_ret &= mantissa_mask<T>();
// Reassemble sign | exponent | mantissa.
620 return (is_negative * sign_mask<T>()) | (exponent_ret <<
mantissa_size_bits()) | mantiss_ret;
// Top-level add of two patterns: handles the special values (infinities,
// NaN) then dispatches to add_patterns / subtract_patterns with operands
// ordered by magnitude and the sign applied afterwards. NOTE(review):
// the branch bodies for the special cases were elided in this chunk.
633 const T abs_a = absolute_value<T>(
a);
634 const T abs_b = absolute_value<T>(
b);
636 const bool sign_a =
a & sign_mask<T>();
637 const bool sign_b =
b & sign_mask<T>();
// inf + inf: same sign keeps the infinity; opposite signs presumably
// produce NaN (result lines elided).
648 if (is_infinite<T>(
a) && is_infinite<T>(
b))
650 if (sign_a == sign_b)
// NaN propagation for either operand.
660 else if (is_NaN<T>(
a))
664 else if (is_NaN<T>(
b))
// Same sign: magnitude addition; the two calls presumably correspond to
// the |a| >= |b| and |a| < |b| orderings (condition lines elided).
669 if (sign_a == sign_b)
673 return add_patterns<T>(
a,
b, rt);
677 return add_patterns<T>(
b,
a, rt);
// Opposite signs: magnitude subtraction, larger magnitude first, with the
// larger operand's sign reattached; equal magnitudes handled separately.
684 return (sign_a * sign_mask<T>()) | subtract_patterns<T>(abs_a, abs_b, rt);
686 else if (abs_a == abs_b)
692 return (sign_b * sign_mask<T>()) | subtract_patterns<T>(abs_b, abs_a, rt);
// Subtraction is implemented as addition with b's sign flipped.
701 return add<T>(
a,
b ^ sign_mask<T>(), rt);
// Conversion between two IEEE754_like formats, parameterized on the larger
// and smaller format. NOTE(review): the function signatures and several
// variable declarations (exponent, mantissa, sign_bit, delta_exponents,
// mantissa_keep_mask, last_mantissa_bit) were elided from this chunk.
706 template <
class FLarge,
class FSmall>
709 static_assert(FSmall::mantissa_size_bits() <= FLarge::mantissa_size_bits() &&
710 FSmall::exponent_size_bits() <= FLarge::exponent_size_bits() );
// Downward (narrowing) conversion: large source format -> small destination.
716 using FDest = FSmall;
717 using FSource = FLarge;
718 using namespace OperatorsHelper;
// Classify the source exponent against the destination's range: overflow
// (exponent_full), finite-source overflow whose mantissa must be cleared to
// make an infinity (delete_mantissa), denormal, and underflow-to-zero.
727 const bool exponent_full = (exponent > delta_exponents +
FDest::template max_exponent_with_bias<T>());
728 const bool delete_mantissa = exponent_full && exponent <= FSource::template max_exponent_with_bias<T>();
731 const bool denormal = exponent <= delta_exponents;
732 const bool zero = exponent + FDest::mantissa_size_bits() <= delta_exponents;
// Rebias the exponent, saturating at the destination's inf/NaN encoding.
734 const T final_exponent =
min(clamped_sub(exponent, delta_exponents),
FDest::template max_exponent_with_bias<T>() + 1);
// Denormal results need an additional right shift of the mantissa.
736 const T extra_mantissa_shift = clamped_sub(delta_exponents + 1, exponent) * denormal;
737 const T total_mantissa_shift = (FSource::mantissa_size_bits() - FDest::mantissa_size_bits()) + extra_mantissa_shift;
// Round with the discarded low mantissa bits of the source.
740 const T check_for_rounding_mask =
FSource::template mantissa_mask<T>() & (~mantissa_keep_mask);
742 const T first_discarded_bit_mask = last_mantissa_bit >> 1;
// The implicit leading bit becomes explicit when the result is denormal.
745 const T extra_denormal_bit = safe_rshift(T(denormal) << FDest::mantissa_size_bits(), extra_mantissa_shift);
747 const bool round_up = (mantissa & check_for_rounding_mask) > first_discarded_bit_mask;
748 const bool tie_break = ((mantissa & check_for_rounding_mask) == first_discarded_bit_mask);
750 const bool round_bit =
FDest::template round_results<T>(sign_bit, mantissa & last_mantissa_bit, round_up, tie_break, rt) && !(exponent_full && !delete_mantissa && !denormal);
753 T final_mantissa = (safe_rshift(mantissa, total_mantissa_shift) | extra_denormal_bit) + round_bit;
755 final_mantissa *= !delete_mantissa;
// Assemble, zeroing everything but the sign on underflow to zero.
757 return sign_bit *
FDest::template sign_mask<T>() | ((final_exponent << FDest::mantissa_size_bits()) | final_mantissa) * !
zero;
// Upward (widening) conversion: small source format -> large destination.
764 using FDest = FLarge;
765 using FSource = FSmall;
766 using namespace OperatorsHelper;
// Source denormals become normal in a wider exponent range: locate the
// leading mantissa bit, shift it out and fix the exponent accordingly.
774 if (exponent == 0 && FDest::exponent_size_bits() > FSource::exponent_size_bits())
776 const unsigned int leading_zeros = LeadingZerosPortability::count_leading_zeros<T>(mantissa);
777 const unsigned int shift_amount = leading_zeros - (
sizeof(T) * CHAR_BIT - FSource::mantissa_size_bits()) + 1;
778 const unsigned int exponent_offset = (mantissa != 0) * (shift_amount - 1);
779 mantissa = (mantissa << shift_amount) & FSource::template mantissa_mask<T>();
781 exponent = delta_exponents - exponent_offset;
// Normal numbers just get the exponent rebias (branch structure elided).
790 exponent = exponent + delta_exponents;
// Assemble, left-aligning the mantissa into the wider field.
794 return sign_bit *
FDest::template sign_mask<T>() | (exponent << FDest::mantissa_size_bits()) | (mantissa << (FDest::mantissa_size_bits() - FSource::mantissa_size_bits()));
// Public conversion entry points (bodies elided): each checks at compile
// time that the conversion direction matches the helper it delegates to.
// Narrowing conversion: destination must not be larger than the source.
806 template <
class T,
class FDest,
class FSource>
809 static_assert(FDest::mantissa_size_bits() <= FSource::mantissa_size_bits() &&
810 FDest::exponent_size_bits() <= FSource::exponent_size_bits(),
811 "The destination type must not be a larger floating point type than the source one.");
// Widening conversion: source must not be larger than the destination.
822 template <
class T,
class FDest,
class FSource>
825 static_assert(FDest::mantissa_size_bits() >= FSource::mantissa_size_bits() &&
826 FDest::exponent_size_bits() >= FSource::exponent_size_bits(),
827 "The source type must not be a larger floating point type than the destination one.");
// Specializations for the standard 32-bit and 64-bit formats: use the
// hardware float/double arithmetic directly on the bit-cast values instead
// of the software implementation. NOTE(review): the member-function
// signatures (and the rounding-mode parameter handling) were elided.
// float addition on uint32 patterns.
836 template <>
template<
class T>
839 const float float_a = BitCastHelper::bitcast<float, T>(
a);
840 const float float_b = BitCastHelper::bitcast<float, T>(
b);
842 float float_ret = float_a + float_b;
844 return BitCastHelper::bitcast<uint32_t, float>(float_ret);
// double addition on uint64 patterns.
847 template <>
template<
class T>
850 const double double_a = BitCastHelper::bitcast<double, T>(
a);
851 const double double_b = BitCastHelper::bitcast<double, T>(
b);
853 double double_ret = double_a + double_b;
855 return BitCastHelper::bitcast<uint64_t, double>(double_ret);
// float subtraction on uint32 patterns.
858 template <>
template<
class T>
861 const float float_a = BitCastHelper::bitcast<float, T>(
a);
862 const float float_b = BitCastHelper::bitcast<float, T>(
b);
864 float float_ret = float_a - float_b;
866 return BitCastHelper::bitcast<uint32_t, float>(float_ret);
// double subtraction on uint64 patterns.
869 template <>
template<
class T>
872 const double double_a = BitCastHelper::bitcast<double, T>(
a);
873 const double double_b = BitCastHelper::bitcast<double, T>(
b);
875 double double_ret = double_a - double_b;
877 return BitCastHelper::bitcast<uint64_t, double>(double_ret);
// float pattern -> double pattern conversion fragment (the intermediate
// `double d = f;` widening line was elided).
886 const float f = BitCastHelper::bitcast<float, uint32_t>(
pattern);
888 return BitCastHelper::bitcast<T, double>(
d);
// double pattern -> float pattern conversion fragment (the narrowing
// `float f = d;` line was elided; rounding-mode handling not visible).
893 const double d = BitCastHelper::bitcast<double, uint64_t>(
pattern);
895 return BitCastHelper::bitcast<uint32_t, float>(
f);
// Fragment of a format-templated helper (owner signature elided).
900 template<
class Format>
// Device-side specializations for the CUDA native 16-bit formats.
918 #if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
// __half addition on 16-bit patterns via the __hadd intrinsic.
923 template <>
template<
class T>
926 const __half conv_a = BitCastHelper::bitcast<__half, T>(
a);
927 const __half conv_b = BitCastHelper::bitcast<__half, T>(
b);
929 __half conv_ret = __hadd(conv_a, conv_b);
931 return BitCastHelper::bitcast<uint16_t, __half>(conv_ret);
// __half subtraction specialization follows (template header only; the
// statements belong to the next block).
934 template <>
template<
class T>
// __half subtraction on 16-bit patterns via the __hsub intrinsic.
// Fixed: the operands were declared __nv_bfloat16 although they are
// produced by bitcast<__half, T> and consumed by the __half overload of
// __hsub (compare the consistent __nv_bfloat16 block further below).
const __half conv_a = BitCastHelper::bitcast<__half, T>(a);
const __half conv_b = BitCastHelper::bitcast<__half, T>(b);

__half conv_ret = __hsub(conv_a, conv_b);

return BitCastHelper::bitcast<uint16_t, __half>(conv_ret);
// __nv_bfloat16 addition on 16-bit patterns via the __hadd intrinsic.
945 template <>
template<
class T>
948 const __nv_bfloat16 conv_a = BitCastHelper::bitcast<__nv_bfloat16, T>(
a);
949 const __nv_bfloat16 conv_b = BitCastHelper::bitcast<__nv_bfloat16, T>(
b);
951 __nv_bfloat16 conv_ret = __hadd(conv_a, conv_b);
953 return BitCastHelper::bitcast<uint16_t, __nv_bfloat16>(conv_ret);
// __nv_bfloat16 subtraction via the __hsub intrinsic.
956 template <>
template<
class T>
959 const __nv_bfloat16 conv_a = BitCastHelper::bitcast<__nv_bfloat16, T>(
a);
960 const __nv_bfloat16 conv_b = BitCastHelper::bitcast<__nv_bfloat16, T>(
b);
962 __nv_bfloat16 conv_ret = __hsub(conv_a, conv_b);
964 return BitCastHelper::bitcast<uint16_t, __nv_bfloat16>(conv_ret);
// half pattern -> float pattern (exact widening via __half2float).
973 const __half hf = BitCastHelper::bitcast<__half, T>(
pattern);
974 const float fl = __half2float(hf);
975 return BitCastHelper::bitcast<T, float>(fl);
// float pattern -> half pattern, choosing the __float2half_* intrinsic
// that matches the requested rounding mode (the switch/case structure and
// the `__half ret;` declaration were elided in this chunk).
980 const float pre_conv = BitCastHelper::bitcast<float, T>(
pattern);
985 ret = __float2half_ru(pre_conv);
988 ret = __float2half_rd(pre_conv);
991 ret = __float2half_rz(pre_conv);
994 ret = __float2half_rn(pre_conv);
999 ret = __float2half(pre_conv);
1002 return BitCastHelper::bitcast<T, __half>(ret);
// bfloat16 pattern -> float pattern (exact widening via __bfloat162float).
1013 const __nv_bfloat16 hf = BitCastHelper::bitcast<__nv_bfloat16, T>(
pattern);
1014 const float fl = __bfloat162float(hf);
1015 return BitCastHelper::bitcast<T, float>(fl);
// float pattern -> bfloat16 pattern, mirroring the __half case above
// (switch/case structure and `ret` declaration elided here too).
1020 const float pre_conv = BitCastHelper::bitcast<float, T>(
pattern);
1025 ret = __float2bfloat16_ru(pre_conv);
1028 ret = __float2bfloat16_rd(pre_conv);
1031 ret = __float2bfloat16_rz(pre_conv);
1034 ret = __float2bfloat16_rn(pre_conv);
// Default rounding mode: plain __float2bfloat16 conversion.
// Fixed: the statement was missing its terminating semicolon (compare the
// __float2bfloat16_ru/_rd/_rz/_rn assignments directly above).
ret = __float2bfloat16(pre_conv);
1042 return BitCastHelper::bitcast<T, __nv_bfloat16>(ret);
// double pattern -> half pattern via the __double2half intrinsic.
// NOTE(review): the enclosing function signatures and any rounding-mode
// handling were elided from this chunk.
1060 const double d = BitCastHelper::bitcast<double, uint64_t>(
pattern);
1061 return BitCastHelper::bitcast<T, __half>(__double2half(
d));
// double pattern -> bfloat16 pattern via the __double2bfloat16 intrinsic.
1085 const double d = BitCastHelper::bitcast<double, uint64_t>(
pattern);
1086 return BitCastHelper::bitcast<T, __nv_bfloat16>(__double2bfloat16(
d));