ATLAS Offline Software
Loading...
Searching...
No Matches
FPHelpers.h
Go to the documentation of this file.
1//
2// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3//
4// Dear emacs, this is -*- c++ -*-
5//
6
7#ifndef CALORECGPU_FPHELPERS_H
8#define CALORECGPU_FPHELPERS_H
9
10#ifndef CALORECGPU_INCLUDE_CUDA_SUPPORT
11
12 #define CALORECGPU_INCLUDE_CUDA_SUPPORT 1
13
14 //If CUDA is available, we will support its native floating point operations.
15 //We can disable this by defining CALORECGPU_INCLUDE_CUDA_SUPPORT as 0...
16
17#endif
18
19#include <cstdint>
20#include <climits>
21#include <cstring>
22
23#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
24
25 #include "cuda_fp16.h"
26
27 #include "cuda_bf16.h"
28
29#endif
30
31#if __cpp_lib_bitops || __cpp_lib_bit_cast
32
33 #include <bit>
34
35#endif
36
37
41
42//In its current form, it is really only used
43//to provide to_total_ordering for floats
44//used in the GPU. For a while (before we
45//came up with the current tag assignment
46//for the splitter), it provided us with
47//several utilities to emulate less precise
48//floating point numbers so we could
49//squash the energy of the tags...
50
52{
53
61
63 {
64
#if __cpp_lib_bitops

  //Number of consecutive zero bits starting from the most significant bit of num.
  //NOTE: std::countl_zero only accepts unsigned integer types; signed types
  //rely on the explicit specializations provided below.
  template <class T>
  inline static constexpr unsigned int count_leading_zeros(const T num)
  {
    return std::countl_zero(num);
  }

#else

  //Portable fallback: number of consecutive zero bits starting from the
  //most significant bit of num, or the full bit width of T when num is zero.
  //The scan shifts num itself downwards instead of shifting a probe mask:
  //for signed T, right-shifting a negative probe (arithmetic shift) would
  //never reach zero and the loop would not terminate for num == 0.
  //This is a simple linear scan; faster versions are provided below through
  //compiler or device built-ins where available.
  template <class T>
  inline static constexpr unsigned int count_leading_zeros(const T num)
  {
    constexpr unsigned int num_bits = sizeof(T) * CHAR_BIT;
    for (unsigned int zeros = 0; zeros < num_bits; ++zeros)
      {
        const unsigned int bit_index = num_bits - 1 - zeros;
        if ((num >> bit_index) & T(1))
          {
            return zeros;
          }
      }
    return num_bits;
  }

#endif
93
//Generates the explicit specialization count_leading_zeros<TYPE>, forwarding
//to the compiler/device intrinsic BUILTIN. ATTRIB is placed before the
//declaration (e.g. constexpr or __device__, or nothing).
//A zero input is answered directly with the bit width of TYPE instead of
//being passed to the intrinsic (GCC documents __builtin_clz(0) as undefined).
#define CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(ATTRIB, TYPE, BUILTIN)   \
  template<>                                                                    \
  ATTRIB inline unsigned int count_leading_zeros(const TYPE num)                \
  {                                                                             \
    return (num == 0) ? static_cast<unsigned int>(sizeof(TYPE) * CHAR_BIT)      \
                      : static_cast<unsigned int>(BUILTIN(num));                \
  }
104
105
106
//Explicit specializations of count_leading_zeros that forward to
//compiler or device intrinsics, chosen by the preprocessor below.
//Device code (CUDA): forwards to the __clz / __clzll intrinsics.
107#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
108
109
111
//NOTE(review): no `int` or short specialization is visible for device code in
//this listing; if there is none, those types fall back to the generic template — confirm.
112 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(__device__, unsigned int, __clz)
113
114 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(__device__, long long, __clzll)
115
116 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(__device__, unsigned long long, __clzll)
117
118
//GCC / Clang host code: __builtin_clz family, declared constexpr here.
119#elif defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
120
121
122 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, int, __builtin_clz)
123
124 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, unsigned int, __builtin_clz)
125
126 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, long, __builtin_clzl)
127
128 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, unsigned long, __builtin_clzl)
129
130 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, long long, __builtin_clzll)
131
132 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, unsigned long long, __builtin_clzll)
133
134
//Only Clang provides the 16-bit __builtin_clzs; on GCC, short types use the generic template.
135#if defined(__clang__)
136
137
138 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, short, __builtin_clzs)
139
140 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(constexpr, unsigned short, __builtin_clzs)
141
142#endif
143
144
//MSVC: <intrin.h> is included at global scope, so the enclosing namespaces
//are closed here and reopened right after the include.
//NOTE(review): per Microsoft's documentation the __lzcnt family needs LZCNT
//hardware support (older CPUs execute it as BSR, with different results) — confirm targets.
145#elif defined(_MSC_VER)
146
147
148 }
149}
150
151#include <intrin.h>
152
153namespace FloatingPointHelpers
154{
156 {
157
158 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(, unsigned short, __lzcnt16)
159
161
163
165
166#if defined(_WIN64)
167
168 //__lzcnt64 is only available in 64 bit, I think?
169 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(, unsigned long long int, __lzcnt64)
170
171 CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(, long long int, __lzcnt64)
172
//Without __lzcnt64, 64-bit counts are assembled from two 32-bit counts.
173#else
174 template <>
175 inline unsigned long long int count_leading_zeros<unsigned long long int>(unsigned long long int T num)
176 {
177 const auto res_1 = count_leading_zeros<unsigned int>(num >> 32);
178 return res_1 + (res_1 == 32) * count_leading_zeros<unsigned int>(num);
179 }
180
//NOTE(review): as written, this specialization does not compile: the parameter
//`long long int T num` has a stray `T`, and an explicit specialization must keep the
//primary template's `unsigned int` return type. Its body is not visible in this listing.
181 template <>
182 inline long long int count_leading_zeros<long long int>(long long int T num)
183 {
185 }
186
//End of the _WIN64 check.
187#endif
188
189
//End of the compiler selection chain.
190#endif
191
192 //We could add more compilers here if needed,
193 //but the "big three" and CUDA should already be covered.
194
//The helper macro is only needed to generate the specializations above.
195#undef CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER
196
197 }
198
200 {
201 //Left and right shifts larger than the variable size are UB.
202
203 template <class T>
204 inline static constexpr T safe_lshift(const T x, const T amount)
205 {
206 const bool valid = amount < sizeof(T) * CHAR_BIT;
207 return (x << (amount * valid)) * valid;
208 }
209
210 template <class T>
211 inline static constexpr T safe_rshift(const T x, const T amount)
212 {
213 const bool valid = amount < sizeof(T) * CHAR_BIT;
214 return (x >> (amount * valid)) * valid;
215 }
216
217 //To prevent underflow for unsigned variables
218 template <class T>
219 inline static constexpr T clamped_sub(const T x1, const T x2)
220 {
221 return (x1 - x2) * (x1 >= x2);
222 }
223
224 template <class T>
225 inline static constexpr T min(const T x1, const T x2)
226 {
227 return (x1 > x2) * x2 + (x1 <= x2) * x1;
228 }
229
230 template <class T>
231 inline static constexpr T max(const T x1, const T x2)
232 {
233 return (x1 > x2) * x1 + (x1 <= x2) * x2;
234 }
235
236 template <class T>
237 inline static constexpr T clamp(const T x, const T low, const T high)
238 {
239 return low * (x < low) + high * (x > high) + x * (x >= low && x <= high);
240 }
241
242 //Just for occasional clarity with the arguments.
243
244 template <class T>
245 inline static constexpr T bit_and(const T x1, const T x2)
246 {
247 return x1 & x2;
248 }
249
250 template <class T>
251 inline static constexpr T bit_or(const T x1, const T x2)
252 {
253 return x1 | x2;
254 }
255
256 }
257
259 {
260
#if __cpp_lib_bit_cast

  //Reinterprets the bytes of x as an object of type To.
  //Same-size types go through std::bit_cast. For differently sized types
  //(e.g. a float stored in a 64-bit integer pattern), only the first
  //min(sizeof(To), sizeof(From)) bytes are copied into a value-initialized To.
  //On little-endian hosts this zero-extends / truncates the integer value.
  template <class To, class From>
  constexpr inline static To bitcast(const From & x)
  {
    if constexpr (sizeof(To) == sizeof(From))
      {
        return std::bit_cast<To, From>(x);
      }
    else
      {
        To ret{};
        std::memcpy(&ret, &x, (sizeof(To) < sizeof(From) ? sizeof(To) : sizeof(From)));
        return ret;
      }
  }

#else

  //Reinterprets the bytes of x as an object of type To.
  //Only the first min(sizeof(To), sizeof(From)) bytes are copied into a
  //value-initialized To, so mismatched sizes never read past x. On
  //little-endian hosts this zero-extends / truncates integer patterns.
  //The disadvantage here is that this won't be actually constexpr due to memcpy...
  template <class To, class From>
  constexpr inline static To bitcast(const From & x)
  {
    constexpr std::size_t bytes_to_copy = (sizeof(To) < sizeof(From) ? sizeof(To) : sizeof(From));
    To ret{};
    std::memcpy(&ret, &x, bytes_to_copy);
    return ret;
  }

#endif
283
#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT

  //In device code, bitcast between floating point types and integers is
  //specialized to use CUDA's reinterpretation intrinsics. Where the integer
  //type is wider than the floating point one, the same-width intrinsic is
  //used and ordinary integer conversions widen/narrow the value (on the
  //argument for integer -> float, on the return value for float -> integer).
#define CALORECGPU_CUDACAST_HELPER(TYPE_TO, TYPE_FROM, CONVFUNC)                         \
  template <>                                                                           \
  __device__ constexpr inline TYPE_TO bitcast< TYPE_TO, TYPE_FROM >(const TYPE_FROM &x)   \
  { return CONVFUNC (x); }

  //double <-> 64-bit integers.
  CALORECGPU_CUDACAST_HELPER( int64_t, double, __double_as_longlong )
  CALORECGPU_CUDACAST_HELPER( uint64_t, double, __double_as_longlong )
  CALORECGPU_CUDACAST_HELPER( double, int64_t, __longlong_as_double )
  CALORECGPU_CUDACAST_HELPER( double, uint64_t, __longlong_as_double )

  //float <-> 32-bit integers (exact width) and 64-bit integers (via 32 bits).
  CALORECGPU_CUDACAST_HELPER( int32_t, float, __float_as_int )
  CALORECGPU_CUDACAST_HELPER( uint32_t, float, __float_as_uint )
  CALORECGPU_CUDACAST_HELPER( int64_t, float, __float_as_uint )
  CALORECGPU_CUDACAST_HELPER( uint64_t, float, __float_as_uint )
  CALORECGPU_CUDACAST_HELPER( float, int32_t, __int_as_float )
  CALORECGPU_CUDACAST_HELPER( float, uint32_t, __uint_as_float )
  CALORECGPU_CUDACAST_HELPER( float, int64_t, __uint_as_float )
  CALORECGPU_CUDACAST_HELPER( float, uint64_t, __uint_as_float )

  //__half <-> 16-bit integers (exact width) and wider integers (via 16 bits).
  CALORECGPU_CUDACAST_HELPER( int16_t, __half, __half_as_short )
  CALORECGPU_CUDACAST_HELPER( uint16_t, __half, __half_as_ushort )
  CALORECGPU_CUDACAST_HELPER( int32_t, __half, __half_as_ushort )
  CALORECGPU_CUDACAST_HELPER( uint32_t, __half, __half_as_ushort )
  CALORECGPU_CUDACAST_HELPER( int64_t, __half, __half_as_ushort )
  CALORECGPU_CUDACAST_HELPER( uint64_t, __half, __half_as_ushort )
  CALORECGPU_CUDACAST_HELPER( __half, int16_t, __short_as_half )
  CALORECGPU_CUDACAST_HELPER( __half, uint16_t, __ushort_as_half )
  CALORECGPU_CUDACAST_HELPER( __half, int32_t, __ushort_as_half )
  CALORECGPU_CUDACAST_HELPER( __half, uint32_t, __ushort_as_half )
  CALORECGPU_CUDACAST_HELPER( __half, int64_t, __ushort_as_half )
  CALORECGPU_CUDACAST_HELPER( __half, uint64_t, __ushort_as_half )

  //The analogous __nv_bfloat16 specializations were reported as not working
  //("This is apparently not working?! Why?!") and remain disabled:
  /*
  CALORECGPU_CUDACAST_HELPER( int16_t, __nv_bfloat16, __bfloat16_as_short );
  CALORECGPU_CUDACAST_HELPER( uint16_t, __nv_bfloat16, __bfloat16_as_ushort );
  CALORECGPU_CUDACAST_HELPER( int32_t, __nv_bfloat16, __bfloat16_as_ushort );
  CALORECGPU_CUDACAST_HELPER( uint32_t, __nv_bfloat16, __bfloat16_as_ushort );
  CALORECGPU_CUDACAST_HELPER( int64_t, __nv_bfloat16, __bfloat16_as_ushort );
  CALORECGPU_CUDACAST_HELPER( uint64_t, __nv_bfloat16, __bfloat16_as_ushort );
  CALORECGPU_CUDACAST_HELPER( __nv_bfloat16, int16_t, __short_as_bfloat16 );
  CALORECGPU_CUDACAST_HELPER( __nv_bfloat16, uint16_t, __ushort_as_bfloat16 );
  CALORECGPU_CUDACAST_HELPER( __nv_bfloat16, int32_t, __ushort_as_bfloat16 );
  CALORECGPU_CUDACAST_HELPER( __nv_bfloat16, uint32_t, __ushort_as_bfloat16 );
  CALORECGPU_CUDACAST_HELPER( __nv_bfloat16, int64_t, __ushort_as_bfloat16 );
  CALORECGPU_CUDACAST_HELPER( __nv_bfloat16, uint64_t, __ushort_as_bfloat16 );
  */

#endif
337
338 }
339
340
350 template <unsigned int mantiss, unsigned int exp, unsigned int tag = 1> struct IEEE754_like
351 {
352
353 static_assert(mantiss > 0 && exp > 0, "The exponent and mantissa must contain a positive number of bits!");
354
355 constexpr inline static unsigned int total_size_bits()
356 {
357 return mantiss + exp + 1;
358 }
359
360 constexpr inline static unsigned int mantissa_size_bits()
361 {
362 return mantiss;
363 }
364
365 constexpr inline static unsigned int exponent_size_bits()
366 {
367 return exp;
368 }
369
370 template <class T>
371 constexpr inline static T mantissa_mask()
372 {
373 static_assert(sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
374 "The type must be large enough to hold the bit representation of the floating point." );
375 T ret = (T(1) << mantiss) - 1;
376 return ret;
377 }
378
379 template <class T>
380 constexpr inline static T exponent_mask()
381 {
382 static_assert(sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
383 "The type must be large enough to hold the bit representation of the floating point." );
384
385 T ret = (T(1) << exp) - 1;
386 return ret << mantiss;
387 }
388
389 template <class T>
390 constexpr inline static T sign_mask()
391 {
392 static_assert(sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
393 "The type must be large enough to hold the bit representation of the floating point." );
394 T ret = T(1) << (exp + mantiss);
395 return ret;
396 }
397
//Mask covering every bit of the representation.
//NOTE(review): the body line is not visible in this listing; it is presumably
//sign_mask<T>() | exponent_mask<T>() | mantissa_mask<T>() — confirm.
//to_total_ordering/from_total_ordering XOR a pattern with it to flip its bits.
398 template <class T>
399 constexpr inline static T full_mask()
400 {
402 }
403
404 template <class T>
405 constexpr inline static T exponent_bias()
406 {
407 static_assert(sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
408 "The type must be large enough to hold the bit representation of the floating point." );
409 return (T(1) << (exp - 1)) - 1;
410 }
411
412 template <class T>
413 constexpr inline static T max_exponent_with_bias()
414 {
415 static_assert(sizeof(T) * CHAR_BIT >= (mantiss + exp + 1),
416 "The type must be large enough to hold the bit representation of the floating point." );
417 return exponent_bias<T>() * 2;
418 }
419
420 template <class T>
421 constexpr inline static bool is_infinite(const T pattern)
422 {
423 return (pattern & (~sign_mask<T>())) == exponent_mask<T>();
424 }
425
426 template <class T>
427 constexpr inline static bool is_NaN(const T pattern)
428 {
429 return (pattern & (~sign_mask<T>())) > exponent_mask<T>();
430 //If it also has bits in the mantissa, it's greater than the mask.
431 //Last bit is sign, so signedness of T is of no concern.
432 }
433
434 template <class T>
435 constexpr inline static T absolute_value(const T pattern)
436 {
437 return pattern & (~sign_mask<T>());
438 }
439
443 template <class T>
444 constexpr inline static T to_total_ordering(const T pattern)
445 {
446 const T xor_mask = (!!(pattern & sign_mask<T>()) * full_mask<T>()) | sign_mask<T>();
447 return pattern ^ xor_mask;
448 }
449
453 template <class T>
454 constexpr inline static T from_total_ordering(const T pattern)
455 {
456 const T xor_mask = (!(pattern & sign_mask<T>()) * full_mask<T>()) | sign_mask<T>();
457 return pattern ^ xor_mask;
458 }
459
460 template <class T>
461 constexpr inline static T positive_zero()
462 {
463 return T(0);
464 }
465
466 template <class T>
467 constexpr inline static T negative_zero()
468 {
469 return sign_mask<T>();
470 }
471
472 template <class T>
473 constexpr inline static T positive_infinity()
474 {
475 return exponent_mask<T>();
476 }
477
478 template <class T>
479 constexpr inline static T negative_infinity()
480 {
481 return sign_mask<T>() | exponent_mask<T>();
482 }
483
//Decides whether a truncated magnitude must be incremented by one unit in
//the last place, given:
// - is_negative: the sign of the value being rounded;
// - is_odd: whether the lowest kept bit is set;
// - is_nearer_to_up: the discarded part is more than half a unit;
// - is_tied: the discarded part is exactly half a unit;
// - rt: the requested rounding mode.
//NOTE(review): the case labels are not visible in this listing. Judging by
//the returned expressions, the branches presumably are: toward +inf, toward
//-inf, toward zero, to nearest (ties to even) and to nearest (ties away) — confirm.
//The directed branches ignore whether anything was actually discarded, so
//callers must only apply the result to inexact values.
484 template <class T>
485 constexpr inline static bool round_results(const bool is_negative, const bool is_odd,
486 const bool is_nearer_to_up, const bool is_tied,
487 RoundingModes rt)
488 {
489 switch (rt)
490 {
492 return !is_negative;
494 return is_negative;
496 return 0;
497 //Truncate => do nothing
499 return is_nearer_to_up || (is_odd && is_tied);
501 return is_nearer_to_up || is_tied;
502 default:
503 return 0;
504 }
505 }
506
507
//Emulated addition of the magnitudes of two same-sign patterns a and b,
//with the result taking a's sign. Callers (see add) pass the operand with the
//larger magnitude as a, so that exp_a >= exp_b.
//Mantissas are widened by extra_bits guard bits so that overflow and
//rounding information is not lost before normalization.
511 template <class T>
512 constexpr inline static T add_patterns (const T a, const T b, const RoundingModes rt = RoundingModes::Default)
513 {
514 using namespace OperatorsHelper;
515
516 constexpr unsigned int extra_bits = 2;
517 //One sign and at least one exponent bit, we're safe!
518
519 constexpr T first_not_mantissa_bit = T(1) << mantissa_size_bits();
520
521 const T exp_a = (a & exponent_mask<T>()) >> mantissa_size_bits();
522 const T exp_b = (b & exponent_mask<T>()) >> mantissa_size_bits();
523
//NOTE(review): despite their names, these are true for NORMAL numbers (non-zero
//exponent); they decide whether the implicit leading one is inserted below.
524 const bool a_denormal = (exp_a != 0);
525 const bool b_denormal = (exp_b != 0);
526
527 const bool use_second = (exp_a - exp_b) <= mantissa_size_bits() + 1 + extra_bits;
528 const bool is_negative = a & sign_mask<T>();
529
530 const T mantiss_a = ((a & mantissa_mask<T>()) | (first_not_mantissa_bit * a_denormal)) << extra_bits;
531 const T mantiss_b = ((b & mantissa_mask<T>()) | (first_not_mantissa_bit * b_denormal)) << extra_bits;
532 //To account for the overflow and rounding.
533
534 T mantiss_ret = mantiss_a;
535
//NOTE(review): a denormal operand has effective exponent 1, not 0; aligning with
//the raw exponent difference seems off by one when b is denormal and a is not — confirm.
536 mantiss_ret += safe_rshift(mantiss_b, exp_a - exp_b);
537
//Sticky bit, presumably meant to record whether bits of mantiss_b were lost in the
//alignment above. NOTE(review): the bits dropped by the right shift are the LOW
//(exp_a - exp_b) bits of mantiss_b; a left shift by the same amount masked with the
//mantissa mask does not isolate them — confirm intent.
538 mantiss_ret |= !!(safe_lshift(mantiss_b, exp_a - exp_b) & mantissa_mask<T>()) * use_second;
539
//Normalization: a carry out of the mantissa means the result must be shifted right.
540 const unsigned int leading_zeros = LeadingZerosPortability::count_leading_zeros<T>(mantiss_ret);
541 constexpr unsigned int desired_number_of_zeros = sizeof(T) * CHAR_BIT - mantissa_size_bits() - 1 - extra_bits;
542 const unsigned int shift_amount = clamped_sub(desired_number_of_zeros, leading_zeros);
543
//Rounding information: lowest kept bit, most significant discarded bit, discarded bits.
544 const T last_bit_mask = T(1) << (shift_amount + extra_bits);
545 const T last_discarded_bit_mask = last_bit_mask >> 1;
546 const T round_mask = (last_bit_mask - 1) * !!(last_bit_mask);
547 const bool round_up = (mantiss_ret & round_mask) > last_discarded_bit_mask;
548 const bool tied = last_discarded_bit_mask && ((mantiss_ret & round_mask) == last_discarded_bit_mask);
549
550 bool round_bit = round_results<T>(is_negative, (mantiss_ret & last_bit_mask), round_up, tied, rt) && !!last_bit_mask;
551
552 mantiss_ret = safe_rshift(mantiss_ret, shift_amount + extra_bits);
553
554 mantiss_ret += round_bit * (shift_amount + extra_bits <= sizeof(T) * CHAR_BIT);
555
//The last term handles two denormals whose sum becomes a normal number.
556 const T exponent_ret = exp_a + shift_amount + (exp_a == 0 && mantiss_ret > mantissa_mask<T>());
557
558 mantiss_ret &= mantissa_mask<T>();
559
560 mantiss_ret &= ~( ( exponent_ret > max_exponent_with_bias<T>() ) * mantissa_mask<T>() );
561 //If we somehow summed up to infinity,
562 //unset the remaining bits.
563
564 return (is_negative * sign_mask<T>()) | (exponent_ret << mantissa_size_bits()) | mantiss_ret;
565 }
566
//Emulated subtraction of magnitudes: |a| - |b|, with the result taking a's sign.
//Callers (see add) pass absolute values with |a| > |b|, so exp_a >= exp_b
//and the mantissa difference stays non-negative.
570 template <class T>
571 constexpr inline static T subtract_patterns (const T a, const T b, const RoundingModes rt = RoundingModes::Default)
572 {
573 using namespace OperatorsHelper;
574
575 constexpr unsigned int extra_bits = 2;
576 //One sign and at least one exponent bit, we're safe!
577
578 constexpr T first_not_mantissa_bit = T(1) << mantissa_size_bits();
579
580 const T exp_a = (a & exponent_mask<T>()) >> mantissa_size_bits();
581 const T exp_b = (b & exponent_mask<T>()) >> mantissa_size_bits();
582
583 const bool use_second = (exp_a - exp_b) <= mantissa_size_bits() + 1 + extra_bits;
//NOTE(review): add() passes absolute values here, so this is always false there and
//directed rounding modes treat negative results as positive — confirm.
584 const bool is_negative = a & sign_mask<T>();
585
//Implicit leading one inserted for normal numbers (non-zero exponent).
586 const T mantiss_a = ((a & mantissa_mask<T>()) | (first_not_mantissa_bit * (exp_a != 0))) << extra_bits;
587 const T mantiss_b = ((b & mantissa_mask<T>()) | (first_not_mantissa_bit * (exp_b != 0))) << extra_bits;
588 //To account for the overflow and rounding.
589
590 T mantiss_ret = mantiss_a;
591
592 mantiss_ret -= safe_rshift(mantiss_b, exp_a - exp_b) * use_second;
593
//Sticky bit. NOTE(review): same concern as in add_patterns — this does not
//obviously isolate the bits lost by the alignment shift above; confirm intent.
594 mantiss_ret |= !!(safe_lshift(-mantiss_b, exp_a - exp_b) & mantissa_mask<T>()) * use_second;
595
//Cancellation can leave leading zeros; shift_amount is how far the result must move left.
596 const unsigned int leading_zeros = LeadingZerosPortability::count_leading_zeros<T>(mantiss_ret);
597 constexpr unsigned int desired_number_of_zeros = sizeof(T) * CHAR_BIT - mantissa_size_bits() - 1 - extra_bits;
598 const unsigned int shift_amount = clamped_sub(leading_zeros, desired_number_of_zeros);
599
//Rounding only considers the guard bits, before the normalizing left shift.
600 const T last_bit_mask = T(1) << extra_bits;
601 const T last_discarded_bit_mask = last_bit_mask >> 1;
602 const T round_mask = (last_bit_mask - 1) * !!(last_bit_mask);
603 const bool round_up = (mantiss_ret & round_mask) > last_discarded_bit_mask;
604 const bool tied = last_discarded_bit_mask && ((mantiss_ret & round_mask) == last_discarded_bit_mask);
605
606 bool round_bit = round_results<T>(is_negative, (mantiss_ret & last_bit_mask), round_up, tied, rt) && !!last_bit_mask;
607
608 mantiss_ret >>= extra_bits;
609
610 mantiss_ret += round_bit;
611
612 mantiss_ret = safe_lshift(mantiss_ret, shift_amount);
613
//NOTE(review): clamped_sub deduces a single T from (T, unsigned int) here, which only
//compiles when T is unsigned int — confirm the instantiated pattern types.
614 const T exponent_ret = clamped_sub(exp_a, shift_amount);
615
//If the exponent would go below the minimum, the result becomes denormal.
616 mantiss_ret = safe_rshift(mantiss_ret, clamped_sub(shift_amount, exp_a));
617
618 mantiss_ret &= mantissa_mask<T>();
619
620 return (is_negative * sign_mask<T>()) | (exponent_ret << mantissa_size_bits()) | mantiss_ret;
621 }
622
//Emulated floating point addition of two bit patterns of this format.
//Zeros and NaNs are forwarded directly. Same-sign operands go to add_patterns
//(larger magnitude first). Opposite-sign operands go to subtract_patterns on
//the magnitudes, and the sign of the larger magnitude is reapplied.
630 template <class T>
631 constexpr inline static T add(const T a, const T b, const RoundingModes rt = RoundingModes::Default)
632 {
633 const T abs_a = absolute_value<T>(a);
634 const T abs_b = absolute_value<T>(b);
635
636 const bool sign_a = a & sign_mask<T>();
637 const bool sign_b = b & sign_mask<T>();
638
639 if (abs_b == 0)
640 {
641 return a;
642 }
643 if (abs_a == 0)
644 {
645 return b;
646 }
647
//NOTE(review): the condition of this branch is not visible in this listing. Its body
//(return a for equal signs, a quiet NaN otherwise, i.e. inf - inf) suggests it tests
//whether both operands are infinite — confirm.
649 {
650 if (sign_a == sign_b)
651 {
652 return a;
653 }
654 else
655 {
656 return abs_a | (T(1) << (mantissa_size_bits() - 1));
657 //A "quiet" NaN in most platforms.
658 }
659 }
660 else if (is_NaN<T>(a))
661 {
662 return a;
663 }
664 else if (is_NaN<T>(b))
665 {
666 return b;
667 }
668
669 if (sign_a == sign_b)
670 {
671 if (abs_a >= abs_b)
672 {
673 return add_patterns<T>(a, b, rt);
674 }
675 else
676 {
677 return add_patterns<T>(b, a, rt);
678 }
679 }
680 else
681 {
682 if (abs_a > abs_b)
683 {
684 return (sign_a * sign_mask<T>()) | subtract_patterns<T>(abs_a, abs_b, rt);
685 }
686 else if (abs_a == abs_b)
687 {
//Exact cancellation gives +0 whatever rt is.
//NOTE(review): IEEE 754 specifies -0 here when rounding toward -inf.
688 return 0;
689 }
690 else
691 {
692 return (sign_b * sign_mask<T>()) | subtract_patterns<T>(abs_b, abs_a, rt);
693 }
694 }
695 }
696
697
698 template <class T>
699 constexpr inline static T subtract(const T a, const T b, const RoundingModes rt = RoundingModes::Default)
700 {
701 return add<T>(a, b ^ sign_mask<T>(), rt);
702 }
703
704 };
705
706 template <class FLarge, class FSmall>
708 {
709 static_assert(FSmall::mantissa_size_bits() <= FLarge::mantissa_size_bits() &&
710 FSmall::exponent_size_bits() <= FLarge::exponent_size_bits() );
711
712
713 template <class T>
714 constexpr inline static T down_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
715 {
716 using FDest = FSmall;
717 using FSource = FLarge;
718 using namespace OperatorsHelper;
719
720 const bool sign_bit = pattern & FSource::template sign_mask<T>();
721
722 const T exponent = (pattern & FSource::template exponent_mask<T>()) >> FSource::mantissa_size_bits();
723 const T mantissa = pattern & FSource::template mantissa_mask<T>();
724
725 constexpr T delta_exponents = FSource::template exponent_bias<T>() - FDest::template exponent_bias<T>();
726
727 const bool exponent_full = (exponent > delta_exponents + FDest::template max_exponent_with_bias<T>());
728 const bool delete_mantissa = exponent_full && exponent <= FSource::template max_exponent_with_bias<T>();
729 //If the number is clamped to infinity, we must delete the mantissa
730 //so we don't get a NaN!
731 const bool denormal = exponent <= delta_exponents;
732 const bool zero = exponent + FDest::mantissa_size_bits() <= delta_exponents;
733
734 const T final_exponent = min(clamped_sub(exponent, delta_exponents), FDest::template max_exponent_with_bias<T>() + 1);
735
736 const T extra_mantissa_shift = clamped_sub(delta_exponents + 1, exponent) * denormal;
737 const T total_mantissa_shift = (FSource::mantissa_size_bits() - FDest::mantissa_size_bits()) + extra_mantissa_shift;
738
739 const T mantissa_keep_mask = safe_lshift(FDest::template mantissa_mask<T>(), total_mantissa_shift) & FSource::template mantissa_mask<T>();
740 const T check_for_rounding_mask = FSource::template mantissa_mask<T>() & (~mantissa_keep_mask);
741 const T last_mantissa_bit = safe_lshift(FDest::template mantissa_mask<T>(), total_mantissa_shift) & FSource::template mantissa_mask<T>();
742 const T first_discarded_bit_mask = last_mantissa_bit >> 1;
743 //In case total_mantissa_shift == 0, this is 0 too.
744
745 const T extra_denormal_bit = safe_rshift(T(denormal) << FDest::mantissa_size_bits(), extra_mantissa_shift);
746
747 const bool round_up = (mantissa & check_for_rounding_mask) > first_discarded_bit_mask;
748 const bool tie_break = ((mantissa & check_for_rounding_mask) == first_discarded_bit_mask);
749
750 const bool round_bit = FDest::template round_results<T>(sign_bit, mantissa & last_mantissa_bit, round_up, tie_break, rt) && !(exponent_full && !delete_mantissa && !denormal);
751 //The last part is so that NaN get truncated instead of rounded.
752
753 T final_mantissa = (safe_rshift(mantissa, total_mantissa_shift) | extra_denormal_bit) + round_bit;
754
755 final_mantissa *= !delete_mantissa;
756
757 return sign_bit * FDest::template sign_mask<T>() | ((final_exponent << FDest::mantissa_size_bits()) | final_mantissa) * !zero;
758
759 }
760
761 template <class T>
762 constexpr inline static T up_convert(const T pattern, [[maybe_unused]] const RoundingModes rt = RoundingModes::Default)
763 {
764 using FDest = FLarge;
765 using FSource = FSmall;
766 using namespace OperatorsHelper;
767
768 const bool sign_bit = pattern & FSource::template sign_mask<T>();
769 T exponent = (pattern & FSource::template exponent_mask<T>()) >> FSource::mantissa_size_bits();
770 T mantissa = pattern & FSource::template mantissa_mask<T>();
771
772 constexpr T delta_exponents = (FDest::template exponent_bias<T>() - FSource::template exponent_bias<T>());
773
774 if (exponent == 0 && FDest::exponent_size_bits() > FSource::exponent_size_bits())
775 {
776 const unsigned int leading_zeros = LeadingZerosPortability::count_leading_zeros<T>(mantissa);
777 const unsigned int shift_amount = leading_zeros - (sizeof(T) * CHAR_BIT - FSource::mantissa_size_bits()) + 1;
778 const unsigned int exponent_offset = (mantissa != 0) * (shift_amount - 1);
779 mantissa = (mantissa << shift_amount) & FSource::template mantissa_mask<T>();
780
781 exponent = delta_exponents - exponent_offset;
782 }
783 else if (exponent > FSource::template max_exponent_with_bias<T>())
784 //Infinity or NaN
785 {
786 exponent = FDest::template max_exponent_with_bias<T>() + 1;
787 }
788 else
789 {
790 exponent = exponent + delta_exponents;
791 }
792
793
794 return sign_bit * FDest::template sign_mask<T>() | (exponent << FDest::mantissa_size_bits()) | (mantissa << (FDest::mantissa_size_bits() - FSource::mantissa_size_bits()));
795
796 }
797 };
798
//Converts pattern from the larger floating point format FSource to FDest.
//The static_assert rejects a destination with more mantissa or exponent bits.
//NOTE(review): the return statement is not visible in this listing; presumably
//it forwards to ConversionHelper<FSource, FDest>::down_convert — confirm.
806 template <class T, class FDest, class FSource>
807 constexpr inline static T down_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
808 {
809 static_assert(FDest::mantissa_size_bits() <= FSource::mantissa_size_bits() &&
810 FDest::exponent_size_bits() <= FSource::exponent_size_bits(),
811 "The destination type must not be a larger floating point type than the source one.");
812
814 }
815
816
//Converts pattern from the smaller floating point format FSource to FDest.
//The static_assert rejects a source with more mantissa or exponent bits.
//NOTE(review): the return statement is not visible in this listing; presumably
//it forwards to ConversionHelper<FDest, FSource>::up_convert — confirm.
822 template <class T, class FDest, class FSource>
823 constexpr inline static T up_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
824 {
825 static_assert(FDest::mantissa_size_bits() >= FSource::mantissa_size_bits() &&
826 FDest::exponent_size_bits() >= FSource::exponent_size_bits(),
827 "The source type must not be a larger floating point type than the destination one.");
829 }
830
831
833
835
//Specializations of add/subtract for the native float and double formats:
//instead of emulating the arithmetic, the patterns are reinterpreted as
//float/double and the language operators are used.
//The RoundingModes parameter is unnamed, i.e. ignored: the result uses
//whatever rounding the floating point environment applies.
//NOTE(review): the return statements are not visible in this listing; presumably
//they bitcast the result back to T — confirm.
836 template <> template<class T>
837 constexpr inline T StandardFloat::add (const T a, const T b, const RoundingModes)
838 {
839 const float float_a = BitCastHelper::bitcast<float, T>(a);
840 const float float_b = BitCastHelper::bitcast<float, T>(b);
841
842 float float_ret = float_a + float_b;
843
845 }
846
847 template <> template<class T>
848 constexpr inline T StandardDouble::add (const T a, const T b, const RoundingModes)
849 {
850 const double double_a = BitCastHelper::bitcast<double, T>(a);
851 const double double_b = BitCastHelper::bitcast<double, T>(b);
852
853 double double_ret = double_a + double_b;
854
856 }
857
858 template <> template<class T>
859 constexpr inline T StandardFloat::subtract (const T a, const T b, const RoundingModes)
860 {
861 const float float_a = BitCastHelper::bitcast<float, T>(a);
862 const float float_b = BitCastHelper::bitcast<float, T>(b);
863
864 float float_ret = float_a - float_b;
865
867 }
868
869 template <> template<class T>
870 constexpr inline T StandardDouble::subtract (const T a, const T b, const RoundingModes)
871 {
872 const double double_a = BitCastHelper::bitcast<double, T>(a);
873 const double double_b = BitCastHelper::bitcast<double, T>(b);
874
875 double double_ret = double_a - double_b;
876
878 }
879
//Full specialization of ConversionHelper for native float <-> double conversions
//(the struct header line is not visible in this listing). The patterns are
//reinterpreted as float/double and the language conversions are used; rt is ignored.
//up_convert reads pattern as a 32-bit pattern and down_convert as a 64-bit pattern,
//whatever T is.
//NOTE(review): the return statements are not visible in this listing — confirm they
//bitcast d (resp. f) back to T.
880 template<>
882 {
883 template <class T>
884 constexpr inline static T up_convert(const T pattern, [[maybe_unused]] const RoundingModes rt = RoundingModes::Default)
885 {
886 const float f = BitCastHelper::bitcast<float, uint32_t>(pattern);
887 const double d = f;
889 }
890 template <class T>
891 constexpr inline static T down_convert(const T pattern, [[maybe_unused]] const RoundingModes rt = RoundingModes::Default)
892 {
893 const double d = BitCastHelper::bitcast<double, uint64_t>(pattern);
894 const float f = d;
896 }
897 };
898
899
900 template<class Format>
901 struct ConversionHelper<Format, Format>
902 {
903 template <class T>
904 constexpr inline static T up_convert(const T pattern, [[maybe_unused]] const RoundingModes rt = RoundingModes::Default)
905 {
906 return pattern;
907 }
908 template <class T>
909 constexpr inline static T down_convert(const T pattern, [[maybe_unused]] const RoundingModes rt = RoundingModes::Default)
910 {
911 return pattern;
912 }
913 };
914
917
918#if defined (__CUDA_ARCH__) && CALORECGPU_INCLUDE_CUDA_SUPPORT
919
920 //If not, the CUDA-related ones will just default back to the slower, emulated operations.
921
922
//Device-side add/subtract specializations for the half and bfloat16 formats,
//using CUDA's native __hadd / __hsub on the reinterpreted patterns.
//The RoundingModes parameter is unnamed, i.e. ignored.
//NOTE(review): the return statements are not visible in this listing; presumably
//they bitcast conv_ret back to T — confirm.
923 template <> template<class T>
924 __device__ constexpr inline T CUDAHalfFloat::add (const T a, const T b, const RoundingModes)
925 {
926 const __half conv_a = BitCastHelper::bitcast<__half, T>(a);
927 const __half conv_b = BitCastHelper::bitcast<__half, T>(b);
928
929 __half conv_ret = __hadd(conv_a, conv_b);
930
932 }
933
//NOTE(review): conv_a / conv_b below are declared __nv_bfloat16 but initialized from
//bitcast<__half, T>; the sibling CUDAHalfFloat::add uses __half. This looks like a
//copy-paste slip — presumably they should be __half.
934 template <> template<class T>
935 __device__ constexpr inline T CUDAHalfFloat::subtract (const T a, const T b, const RoundingModes)
936 {
937 const __nv_bfloat16 conv_a = BitCastHelper::bitcast<__half, T>(a);
938 const __nv_bfloat16 conv_b = BitCastHelper::bitcast<__half, T>(b);
939
940 __half conv_ret = __hsub(conv_a, conv_b);
941
943 }
944
945 template <> template<class T>
946 __device__ constexpr inline T CUDABFloat16::add (const T a, const T b, const RoundingModes)
947 {
948 const __nv_bfloat16 conv_a = BitCastHelper::bitcast<__nv_bfloat16, T>(a);
949 const __nv_bfloat16 conv_b = BitCastHelper::bitcast<__nv_bfloat16, T>(b);
950
951 __nv_bfloat16 conv_ret = __hadd(conv_a, conv_b);
952
954 }
955
956 template <> template<class T>
957 __device__ constexpr inline T CUDABFloat16::subtract (const T a, const T b, const RoundingModes)
958 {
959 const __nv_bfloat16 conv_a = BitCastHelper::bitcast<__nv_bfloat16, T>(a);
960 const __nv_bfloat16 conv_b = BitCastHelper::bitcast<__nv_bfloat16, T>(b);
961
962 __nv_bfloat16 conv_ret = __hsub(conv_a, conv_b);
963
965 }
966
//Full specialization of ConversionHelper between float and __half, presumably
//ConversionHelper<StandardFloat, CUDAHalfFloat> (the struct header line is not
//visible in this listing) — confirm. It uses CUDA's conversion intrinsics.
967 template<>
969 {
//Widening is exact, so rt does not matter here.
970 template <class T>
971 constexpr inline static T up_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
972 {
973 const __half hf = BitCastHelper::bitcast<__half, T>(pattern);
974 const float fl = __half2float(hf);
976 }
//Narrowing selects the intrinsic matching rt: _ru (up), _rd (down), _rz (toward zero),
//_rn (nearest even). The case labels are not visible in this listing; any other mode
//(e.g. ties away from zero) falls back to __float2half.
977 template <class T>
978 constexpr inline static T down_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
979 {
980 const float pre_conv = BitCastHelper::bitcast<float, T>(pattern);
981 __half ret;
982 switch (rt)
983 {
985 ret = __float2half_ru(pre_conv);
986 break;
988 ret = __float2half_rd(pre_conv);
989 break;
991 ret = __float2half_rz(pre_conv);
992 break;
994 ret = __float2half_rn(pre_conv);
995 break;
996 //case RoundingModes::ToNearestAwayFromZero:
997 //No support for this
998 default:
999 ret = __float2half(pre_conv);
1000 break;
1001 }
1003 }
1004 };
1005
1006
//Full specialization of ConversionHelper between float and __nv_bfloat16, presumably
//ConversionHelper<StandardFloat, CUDABFloat16> (the struct header line is not
//visible in this listing) — confirm. It uses CUDA's bfloat16 conversion intrinsics.
1007 template<>
1009 {
//Widening is exact, so rt does not matter here.
1010 template <class T>
1011 constexpr inline static T up_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
1012 {
1013 const __nv_bfloat16 hf = BitCastHelper::bitcast<__nv_bfloat16, T>(pattern);
1014 const float fl = __bfloat162float(hf);
1016 }
//Narrowing selects the intrinsic matching rt: _ru, _rd, _rz or _rn; other modes use __float2bfloat16.
//NOTE(review): the `default:` assignment below is missing its terminating semicolon,
//which is a compile error in device builds.
1017 template <class T>
1018 constexpr inline static T down_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
1019 {
1020 const float pre_conv = BitCastHelper::bitcast<float, T>(pattern);
1021 __nv_bfloat16 ret;
1022 switch (rt)
1023 {
1025 ret = __float2bfloat16_ru(pre_conv);
1026 break;
1028 ret = __float2bfloat16_rd(pre_conv);
1029 break;
1031 ret = __float2bfloat16_rz(pre_conv);
1032 break;
1034 ret = __float2bfloat16_rn(pre_conv);
1035 break;
1036 //case RoundingModes::ToNearestAwayFromZero:
1037 //No support for this
1038 default:
1039 ret = __float2bfloat16(pre_conv)
1040 break;
1041 }
1043 }
1044 };
1045
1046 template<>
1048 {
//Half -> double, in two exact widening steps: half -> float, then float -> double.
template <class T>
constexpr inline static T up_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
{
  using HalfToFloat = ConversionHelper<StandardFloat, CUDAHalfFloat>;
  using FloatToDouble = ConversionHelper<StandardDouble, StandardFloat>;
  const T as_float = HalfToFloat::template up_convert<T>(pattern, rt);
  return FloatToDouble::template up_convert<T>(as_float, rt);
}
//Double -> half. One branch converts directly with __double2half; the other goes
//through float. The branch condition is not visible in this listing — presumably
//it selects the direct path for the rounding mode __double2half implements; confirm.
//NOTE(review): in the else branch, first_step is passed to its own initializer
//(reading an uninitialized value); it presumably should pass `pattern`.
1055 template <class T>
1056 constexpr inline static T down_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
1057 {
1059 {
1060 const double d = BitCastHelper::bitcast<double, uint64_t>(pattern);
1061 return BitCastHelper::bitcast<T, __half>(__double2half(d));
1062 }
1063 else
1064 {
1065 const T first_step = ConversionHelper<StandardDouble, StandardFloat>::template down_convert<T>(first_step, rt);
1066 return ConversionHelper<StandardFloat, CUDAHalfFloat>::template down_convert<T>(first_step, rt);
1067 }
1068 }
1069 };
1070
1071 template<>
1073 {
//bfloat16 -> double, in two exact widening steps: bfloat16 -> float, then
//float -> double. The first step must start from the input pattern; the
//previous code passed first_step to its own initializer (an uninitialized read).
template <class T>
constexpr inline static T up_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
{
  const T first_step = ConversionHelper<StandardFloat, CUDABFloat16>::template up_convert<T>(pattern, rt);
  return ConversionHelper<StandardDouble, StandardFloat>::template up_convert<T>(first_step, rt);
}
//Double -> bfloat16. One branch converts directly with __double2bfloat16; the other
//goes through float. The branch condition is not visible in this listing — confirm
//which rounding modes take the direct path.
//NOTE(review): in the else branch, first_step is passed to its own initializer
//(reading an uninitialized value); it presumably should pass `pattern`.
1080 template <class T>
1081 constexpr inline static T down_convert(const T pattern, const RoundingModes rt = RoundingModes::Default)
1082 {
1084 {
1085 const double d = BitCastHelper::bitcast<double, uint64_t>(pattern);
1086 return BitCastHelper::bitcast<T, __nv_bfloat16>(__double2bfloat16(d));
1087 }
1088 else
1089 {
1090 const T first_step = ConversionHelper<StandardDouble, StandardFloat>::template down_convert<T>(first_step, rt);
1091 return ConversionHelper<StandardFloat, CUDABFloat16>::template down_convert<T>(first_step, rt);
1092 }
1093 }
1094 };
1095
1096#endif
1097
1098}
1099
1100
1101#endif
#define CALORECGPU_MULTIPLE_PORTABILITY_CLZ_FUNC_HELPER(ATTRIB, TYPE, BUILTIN)
Definition FPHelpers.h:94
static Double_t a
#define x
#define min(a, b)
Definition cfImp.cxx:40
#define max(a, b)
Definition cfImp.cxx:41
void zero(TH2 *h)
zero the contents of a 2d histogram
static constexpr To bitcast(const From &x)
Definition FPHelpers.h:275
static constexpr unsigned int count_leading_zeros(const T num)
Definition FPHelpers.h:76
static constexpr T bit_or(const T x1, const T x2)
Definition FPHelpers.h:251
static constexpr T safe_lshift(const T x, const T amount)
Definition FPHelpers.h:204
static constexpr T bit_and(const T x1, const T x2)
Definition FPHelpers.h:245
static constexpr T clamp(const T x, const T low, const T high)
Definition FPHelpers.h:237
static constexpr T clamped_sub(const T x1, const T x2)
Definition FPHelpers.h:219
static constexpr T safe_rshift(const T x, const T amount)
Definition FPHelpers.h:211
IEEE754_like< 52, 11 > StandardDouble
Definition FPHelpers.h:834
IEEE754_like< 7, 8 > CUDABFloat16
Definition FPHelpers.h:916
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Converts pattern from the larger floating point format FSource to FDest.
Definition FPHelpers.h:807
IEEE754_like< 23, 8 > StandardFloat
Definition FPHelpers.h:832
RoundingModes
Specifies the rounding mode to use for the operations.
Definition FPHelpers.h:58
IEEE754_like< 10, 5 > CUDAHalfFloat
Definition FPHelpers.h:915
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Converts pattern from the smaller floating point format FSource to FDest.
Definition FPHelpers.h:823
unsigned long long T
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:904
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:909
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:884
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:891
static constexpr T up_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:762
static constexpr T down_convert(const T pattern, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:714
Specifies a floating point format like those described in IEEE-754, with an adjustable number of bits...
Definition FPHelpers.h:351
static constexpr unsigned int mantissa_size_bits()
Definition FPHelpers.h:360
static constexpr T add_patterns(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
The absolute value of a must be greater than or equal to that of b.
Definition FPHelpers.h:512
static constexpr bool is_NaN(const T pattern)
Definition FPHelpers.h:427
static constexpr T exponent_bias()
Definition FPHelpers.h:405
static constexpr T positive_infinity()
Definition FPHelpers.h:473
static constexpr T exponent_mask()
Definition FPHelpers.h:380
static constexpr T add(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:631
static constexpr T subtract(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
Definition FPHelpers.h:699
static constexpr T negative_infinity()
Definition FPHelpers.h:479
static constexpr T negative_zero()
Definition FPHelpers.h:467
static constexpr T from_total_ordering(const T pattern)
Definition FPHelpers.h:454
static constexpr T full_mask()
Definition FPHelpers.h:399
static constexpr T to_total_ordering(const T pattern)
Definition FPHelpers.h:444
static constexpr unsigned int exponent_size_bits()
Definition FPHelpers.h:365
static constexpr T max_exponent_with_bias()
Definition FPHelpers.h:413
static constexpr T mantissa_mask()
Definition FPHelpers.h:371
static constexpr bool round_results(const bool is_negative, const bool is_odd, const bool is_nearer_to_up, const bool is_tied, RoundingModes rt)
Definition FPHelpers.h:485
static constexpr T positive_zero()
Definition FPHelpers.h:461
static constexpr T absolute_value(const T pattern)
Definition FPHelpers.h:435
static constexpr T sign_mask()
Definition FPHelpers.h:390
static constexpr bool is_infinite(const T pattern)
Definition FPHelpers.h:421
static constexpr T subtract_patterns(const T a, const T b, const RoundingModes rt=RoundingModes::Default)
The absolute value of a must be greater than or equal to that of b.
Definition FPHelpers.h:571
static constexpr unsigned int total_size_bits()
Definition FPHelpers.h:355