ATLAS Offline Software
Loading...
Searching...
No Matches
PhysicsAnalysis
JetTagging
FlavorTagInference
Root
SaltModelOutput.cxx
Go to the documentation of this file.
1
/*
2
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3
4
This class stores the configuration of an ONNX output node.
5
*/
6
7
#include "
FlavorTagInference/SaltModelOutput.h
"
8
9
namespace
FlavorTagInference
{
10
11
/* constructor for SaltModelVersion::V1 and higher */
12
SaltModelOutput::SaltModelOutput
(
const
std::string&
name
,
13
const
ONNXTensorElementDataType
type
,
14
int
rank)
15
:
name
(
name
),
16
name_in_model
(
name
),
17
type
(
getOutputType
(
type
, rank)){}
18
19
/* constructor for SaltModelVersion::V0 */
20
SaltModelOutput::SaltModelOutput
(
const
std::string&
name
,
21
const
ONNXTensorElementDataType
type
,
22
const
std::string& model_name)
23
:
name
(
getName
(
name
, model_name)),
24
name_in_model
(
name
),
25
type
(
getOutputType
(
type
, 0)){}
26
27
/**
 * Build the stored output name for a V0 model output.
 *
 * @param name        Raw output node name.
 * @param model_name  Name of the model that produced the output.
 * @return `model_name + "_" + name`, or `name` unchanged when `model_name`
 *         is the "UnknownModelName" placeholder.
 */
const std::string SaltModelOutput::getName(const std::string& name,
                                           const std::string& model_name) const
{
  // Some scheduled taggers lack a well-defined model name and rely on output
  // remapping instead; for those, the raw output name must be kept unprefixed.
  const bool has_known_model_name = (model_name != "UnknownModelName");
  if (has_known_model_name) {
    return model_name + "_" + name;
  }
  return name;
}
35
36
/**
 * Classify an output node from its ONNX element type and tensor rank.
 *
 * Mapping:
 *   - INT8            -> VECCHAR (rank is not consulted)
 *   - FLOAT, rank 0   -> FLOAT
 *   - FLOAT, rank 1   -> VECFLOAT
 *   - anything else   -> UNKNOWN
 *
 * @param type  ONNX element type of the output tensor.
 * @param rank  Rank of the output tensor.
 * @return The corresponding OutputType, or OutputType::UNKNOWN.
 */
SaltModelOutput::OutputType SaltModelOutput::getOutputType(ONNXTensorElementDataType type,
                                                           int rank) const
{
  using ElemType = ONNXTensorElementDataType;

  // INT8 outputs are always treated as character vectors.
  if (type == ElemType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
    return OutputType::VECCHAR;
  }

  // Apart from INT8, only float outputs are supported.
  if (type != ElemType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
    return OutputType::UNKNOWN;
  }

  // Float outputs: a scalar or a 1-D vector; higher ranks are unsupported.
  switch (rank) {
    case 0:
      return OutputType::FLOAT;
    case 1:
      return OutputType::VECFLOAT;
    default:
      return OutputType::UNKNOWN;
  }
}
50
51
}
// namespace FlavorTagInference
SaltModelOutput.h
FlavorTagInference::SaltModelOutput::getOutputType
OutputType getOutputType(ONNXTensorElementDataType type, int rank) const
Definition
SaltModelOutput.cxx:36
FlavorTagInference::SaltModelOutput::type
const OutputType type
Definition
SaltModelOutput.h:33
FlavorTagInference::SaltModelOutput::getName
const std::string getName(const std::string &name, const std::string &model_name) const
Definition
SaltModelOutput.cxx:27
FlavorTagInference::SaltModelOutput::name
const std::string name
Definition
SaltModelOutput.h:31
FlavorTagInference::SaltModelOutput::SaltModelOutput
SaltModelOutput(const std::string &name, ONNXTensorElementDataType type, int rank)
Definition
SaltModelOutput.cxx:12
FlavorTagInference::SaltModelOutput::name_in_model
const std::string name_in_model
Definition
SaltModelOutput.h:32
FlavorTagInference::SaltModelOutput::OutputType
OutputType
Definition
SaltModelOutput.h:19
FlavorTagInference::SaltModelOutput::OutputType::VECCHAR
@ VECCHAR
Definition
SaltModelOutput.h:19
FlavorTagInference::SaltModelOutput::OutputType::UNKNOWN
@ UNKNOWN
Definition
SaltModelOutput.h:19
FlavorTagInference::SaltModelOutput::OutputType::VECFLOAT
@ VECFLOAT
Definition
SaltModelOutput.h:19
FlavorTagInference::SaltModelOutput::OutputType::FLOAT
@ FLOAT
Definition
SaltModelOutput.h:19
FlavorTagInference
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition
PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/ConstituentsLoader.h:27
Generated on
for ATLAS Offline Software by
1.14.0