ATLAS Offline Software
OnnxOutput.cxx
Go to the documentation of this file.
1 /*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 
4 This class is used to store the configuration for an ONNX output node.
5 */
6 
8 
9 namespace FlavorTagDiscriminants {
10 
11 /* constructor for OnnxModelVersion::V1 and higher */
12 OnnxOutput::OnnxOutput(const std::string& name,
13  const ONNXTensorElementDataType type,
14  int rank)
15  : name(name),
16  name_in_model(name),
17  type(getOutputType(type, rank)){}
18 
19 /* constructor for OnnxModelVersion::V0 */
20 OnnxOutput::OnnxOutput(const std::string& name,
21  const ONNXTensorElementDataType type,
22  const std::string& model_name)
23  : name(getName(name, model_name)),
24  name_in_model(name),
25  type(getOutputType(type, 0)){}
26 
27 const std::string OnnxOutput::getName(const std::string& name, const std::string& model_name) const {
28  // unfortunately, this is block is needed to support some taggers that we schedule that don't have
29  // a well defined model name and rely on output remapping.
30  if (model_name == "UnknownModelName") {
31  return name;
32  }
33  return model_name + "_" + name;
34 }
35 
36 OnnxOutput::OutputType OnnxOutput::getOutputType(ONNXTensorElementDataType type, int rank) const {
37  // Determine the output node type based on the type and shape of the output tensor.
38  using ORT = ONNXTensorElementDataType;
39  if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
40  if (rank == 0) {
41  return OutputType::FLOAT;
42  } else if (rank == 1) {
43  return OutputType::VECFLOAT;
44  }
45  } else if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
46  return OutputType::VECCHAR;
47  }
48  return OutputType::UNKNOWN;
49 }
50 
51 } // namespace FlavorTagDiscriminants
FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR
@ VECCHAR
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT
@ FLOAT
FlavorTagDiscriminants::OnnxOutput::name
const std::string name
Definition: OnnxOutput.h:31
FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT
@ VECFLOAT
FlavorTagDiscriminants::OnnxOutput::getOutputType
OutputType getOutputType(ONNXTensorElementDataType type, int rank) const
Definition: OnnxOutput.cxx:36
FlavorTagDiscriminants::OnnxOutput::type
const OutputType type
Definition: OnnxOutput.h:33
FlavorTagDiscriminants::OnnxOutput::getName
const std::string getName(const std::string &name, const std::string &model_name) const
Definition: OnnxOutput.cxx:27
FlavorTagDiscriminants::OnnxOutput::OnnxOutput
OnnxOutput(const std::string &name, ONNXTensorElementDataType type, int rank)
Definition: OnnxOutput.cxx:12
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
OnnxOutput.h
FlavorTagDiscriminants::OnnxOutput::OutputType
OutputType
Definition: OnnxOutput.h:19
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
FlavorTagDiscriminants::OnnxOutput::OutputType::UNKNOWN
@ UNKNOWN