ATLAS Offline Software
SaltModelOutput.cxx
Go to the documentation of this file.
1 /*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 
4 This class is used to store the configuration for an ONNX output node.
5 */
6 
8 
9 namespace FlavorTagInference {
10 
11 /* constructor for SaltModelVersion::V1 and higher */
13  const ONNXTensorElementDataType type,
14  int rank)
15  : name(name),
16  name_in_model(name),
17  type(getOutputType(type, rank)){}
18 
19 /* constructor for SaltModelVersion::V0 */
21  const ONNXTensorElementDataType type,
22  const std::string& model_name)
23  : name(getName(name, model_name)),
24  name_in_model(name),
25  type(getOutputType(type, 0)){}
26 
27 const std::string SaltModelOutput::getName(const std::string& name, const std::string& model_name) const {
28  // unfortunately, this is block is needed to support some taggers that we schedule that don't have
29  // a well defined model name and rely on output remapping.
30  if (model_name == "UnknownModelName") {
31  return name;
32  }
33  return model_name + "_" + name;
34 }
35 
36 SaltModelOutput::OutputType SaltModelOutput::getOutputType(ONNXTensorElementDataType type, int rank) const {
37  // Determine the output node type based on the type and shape of the output tensor.
38  using ORT = ONNXTensorElementDataType;
39  if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
40  if (rank == 0) {
41  return OutputType::FLOAT;
42  } else if (rank == 1) {
43  return OutputType::VECFLOAT;
44  }
45  } else if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
46  return OutputType::VECCHAR;
47  }
48  return OutputType::UNKNOWN;
49 }
50 
51 } // namespace FlavorTagInference
FlavorTagInference::SaltModelOutput::OutputType::VECCHAR
@ VECCHAR
FlavorTagInference
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/ConstituentsLoader.h:27
FlavorTagInference::SaltModelOutput::SaltModelOutput
SaltModelOutput(const std::string &name, ONNXTensorElementDataType type, int rank)
Definition: SaltModelOutput.cxx:12
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
dumpTruth.getName
getName
Definition: dumpTruth.py:34
FlavorTagInference::SaltModelOutput::OutputType::UNKNOWN
@ UNKNOWN
SaltModelOutput.h
FlavorTagInference::SaltModelOutput::type
const OutputType type
Definition: SaltModelOutput.h:33
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
FlavorTagInference::SaltModelOutput::name
const std::string name
Definition: SaltModelOutput.h:31
FlavorTagInference::SaltModelOutput::OutputType::FLOAT
@ FLOAT
FlavorTagInference::SaltModelOutput::OutputType
OutputType
Definition: SaltModelOutput.h:19
FlavorTagInference::SaltModelOutput::getName
const std::string getName(const std::string &name, const std::string &model_name) const
Definition: SaltModelOutput.cxx:27
FlavorTagInference::SaltModelOutput::OutputType::VECFLOAT
@ VECFLOAT
FlavorTagInference::SaltModelOutput::getOutputType
OutputType getOutputType(ONNXTensorElementDataType type, int rank) const
Definition: SaltModelOutput.cxx:36