ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelOutput.cxx
Go to the documentation of this file.
1/*
2Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3
4This class is used to store the configuration for an ONNX output node.
5*/
6
8
9namespace FlavorTagInference {
10
11/* constructor for SaltModelVersion::V1 and higher */
13 const ONNXTensorElementDataType type,
14 int rank)
15 : name(name),
17 type(getOutputType(type, rank)){}
18
19/* constructor for SaltModelVersion::V0 */
21 const ONNXTensorElementDataType type,
22 const std::string& model_name)
23 : name(getName(name, model_name)),
26
27/* constructor for parametric reduced-precision float32 (VECTRUNCFLOAT with explicit E,M) */
30 float scale,
31 int exp_bits_in,
32 int man_bits_in)
33 : name(name),
35 type(type),
36 scale(scale),
37 exp_bits(exp_bits_in),
38 man_bits(man_bits_in){}
39
40const std::string SaltModelOutput::getName(const std::string& name, const std::string& model_name) {
41 // unfortunately, this is block is needed to support some taggers that we schedule that don't have
42 // a well defined model name and rely on output remapping.
43 if (model_name == "UnknownModelName") {
44 return name;
45 }
46 return model_name + "_" + name;
47}
48
49SaltModelOutput::OutputType SaltModelOutput::getOutputType(ONNXTensorElementDataType type, int rank) const {
50 // Determine the output node type based on the type and shape of the output tensor.
51 using ORT = ONNXTensorElementDataType;
52 if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
53 if (rank == 0) {
54 return OutputType::FLOAT;
55 } else if (rank == 1) {
57 }
58 } else if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
60 } else if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
61 if (rank == 1) {
62 return OutputType::VECINT;
63 }
64 }
66}
67
68} // namespace FlavorTagInference
OutputType getOutputType(ONNXTensorElementDataType type, int rank) const
SaltModelOutput(const std::string &name, ONNXTensorElementDataType type, int rank)
static const std::string getName(const std::string &name, const std::string &model_name)
This file contains "getter" functions used for accessing tagger inputs from the EDM.