ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelOutput.cxx
Go to the documentation of this file.
1/*
2Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3
4This class is used to store the configuration for an ONNX output node.
5*/
6
8
9namespace FlavorTagInference {
10
11/* constructor for SaltModelVersion::V1 and higher */
13 const ONNXTensorElementDataType type,
14 int rank)
15 : name(name),
17 type(getOutputType(type, rank)){}
18
19/* constructor for SaltModelVersion::V0 */
21 const ONNXTensorElementDataType type,
22 const std::string& model_name)
23 : name(getName(name, model_name)),
26
27const std::string SaltModelOutput::getName(const std::string& name, const std::string& model_name) const {
28 // unfortunately, this is block is needed to support some taggers that we schedule that don't have
29 // a well defined model name and rely on output remapping.
30 if (model_name == "UnknownModelName") {
31 return name;
32 }
33 return model_name + "_" + name;
34}
35
36SaltModelOutput::OutputType SaltModelOutput::getOutputType(ONNXTensorElementDataType type, int rank) const {
37 // Determine the output node type based on the type and shape of the output tensor.
38 using ORT = ONNXTensorElementDataType;
39 if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
40 if (rank == 0) {
41 return OutputType::FLOAT;
42 } else if (rank == 1) {
44 }
45 } else if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
47 }
49}
50
51} // namespace FlavorTagInference
OutputType getOutputType(ONNXTensorElementDataType type, int rank) const
const std::string getName(const std::string &name, const std::string &model_name) const
SaltModelOutput(const std::string &name, ONNXTensorElementDataType type, int rank)
This file contains "getter" functions used for accessing tagger inputs from the EDM.