ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelOutput.h
Go to the documentation of this file.
1/*
2Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3
4This class is used to store the configuration for an ONNX output node.
5*/
6
7#ifndef FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H
8#define FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H
9
10#include <onnxruntime_cxx_api.h>
11#include "nlohmann/json.hpp"
12#include <string>
13
14namespace FlavorTagInference {
15
// NOTE(review): this is a Doxygen source listing with the listing line
// numbers fused onto each line. Listing lines 16, 19, 33 and 40 are missing
// from this extraction. Presumably 16 is the `class SaltModelOutput {` head,
// and 19 declares `OutputType`, which is used by getOutputType below but not
// declared anywhere shown here. Confirm against the real header before editing.
//
// SaltModelOutput stores the configuration of a single ONNX output node:
// a name, a name_in_model, a scale factor, and the exponent/mantissa bit
// widths used for VECTRUNCFLOAT outputs.
17
18 public:
20
21 /* constructor for SaltModelVersion::V1 and higher */
// Takes the output name, its ONNX element type and its rank. Presumably the
// output type is derived from (type, rank) via the private getOutputType
// helper; confirm in the implementation file.
22 SaltModelOutput(const std::string& name,
23 ONNXTensorElementDataType type,
24 int rank);
25
26 /* constructor for SaltModelVersion::V0 */
// Takes an explicit name_in_model instead of a rank. Presumably the stored
// name is resolved through the private getName(name, model_name) helper;
// confirm in the implementation file.
27 SaltModelOutput(const std::string& name,
28 ONNXTensorElementDataType type,
29 const std::string& name_in_model);
30
31 /* constructor for parametric reduced-precision float32 (VECTRUNCFLOAT with explicit E,M) */
// Caller-chosen scale and exponent/mantissa bit widths.
// NOTE(review): the parameter between `name` and `scale` (listing line 33)
// is missing from this extraction; confirm against the real header.
32 SaltModelOutput(const std::string& name,
34 float scale,
35 int exp_bits,
36 int man_bits);
37
38 const std::string name;
39 const std::string name_in_model;
// Listing line 40 is missing here; presumably it declares the output type member.
// scale defaults to 1.0f when a constructor does not set it.
41 const float scale{1.0f};
// The defaults 8/7 match the bfloat16 layout (8 exponent bits, 7 mantissa bits).
// NOTE(review): unlike name, name_in_model and scale, these two members are
// not const, so they can be modified after construction. Make them const if
// that is not intended.
42 int exp_bits{8}; // exponent bits for VECTRUNCFLOAT (default: bf16-equivalent)
43 int man_bits{7}; // mantissa bits for VECTRUNCFLOAT (default: bf16-equivalent)
44
45 private:
// Maps an ONNX element type and tensor rank to an OutputType.
46 OutputType getOutputType(ONNXTensorElementDataType type, int rank) const;
// Builds a name string from `name` and `model_name`; the exact rule lives in
// the implementation file.
// NOTE(review): returning `const std::string` by value prevents callers from
// moving from the result. A plain `std::string` return type would be preferable.
47 static const std::string getName(const std::string& name, const std::string& model_name);
48
49}; // class SaltModelOutput
50
51} // namespace FlavorTagInference
52
53#endif // FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H
OutputType getOutputType(ONNXTensorElementDataType type, int rank) const
SaltModelOutput(const std::string &name, ONNXTensorElementDataType type, int rank)
static const std::string getName(const std::string &name, const std::string &model_name)
This file contains "getter" functions used for accessing tagger inputs from the EDM.