13 :
14 FlavorTagInference::SaltModelEDMLoaderBase(salt_model),
15 asg::AsgMessaging("TauGNNDataLoader")
16 {
18 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* scalar_input_node = nullptr;
19 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* track_input_node = nullptr;
20 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* cluster_input_node = nullptr;
21 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* hit_input_node = nullptr;
23 if (in_node.name ==
config.input_layer_scalar) {
24 scalar_input_node = &in_node;
26 break;
27 }
28 }
29 for (
const auto &in_node :
graph_config.input_sequences) {
30 if (in_node.name ==
config.input_layer_tracks) {
31 track_input_node = &in_node;
33 }
34 if (in_node.name ==
config.input_layer_clusters) {
35 cluster_input_node = &in_node;
37 }
38 if (in_node.name ==
config.input_layer_hits) {
39 hit_input_node = &in_node;
41 }
42 }
43
44
45 if (scalar_input_node) {
46 for (
const auto &in : scalar_input_node->
variables) {
48 }
49 }
else if(!
config.input_layer_scalar.empty()) {
50 ATH_MSG_ERROR(
"Scalar input node '" +
config.input_layer_scalar +
"' not found in the model input configuration");
51 throw std::runtime_error(
"Scalar input node '" +
config.input_layer_scalar +
"' not found in the model input configuration");
52 }
53
54 if (track_input_node) {
55 FlavorTagInference::ConstituentsInputConfig trk_config;
56 trk_config.
name =
"tautracks";
63 for (
const auto &in : track_input_node->
variables) {
64 if (!
config.useTRT && (in.name ==
"eProbabilityHT")) {
65 ATH_MSG_WARNING(
"Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
67 continue;
68 }
70 }
71 addVectorLoader(
config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
72 }
else if(!
config.input_layer_tracks.empty() &&
config.n_max_tracks > 0) {
73 ATH_MSG_ERROR(
"Track input node '" +
config.input_layer_tracks +
"' not found in the model input configuration");
74 throw std::runtime_error(
"Track input node '" +
config.input_layer_tracks +
"' not found in the model input configuration");
75 }
76
77 if (cluster_input_node) {
78 FlavorTagInference::ConstituentsInputConfig cls_config;
79 cls_config.
name =
"tauclusters";
86 for (
const auto &in : cluster_input_node->
variables) {
88 }
89 addVectorLoader(
config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config,
config.max_dr_cluster,
config.doVertexCorrection));
90 }
else if(!
config.input_layer_clusters.empty() &&
config.n_max_clusters > 0) {
91 ATH_MSG_ERROR(
"Cluster input node '" +
config.input_layer_clusters +
"' not found in the model input configuration");
92 throw std::runtime_error(
"Cluster input node '" +
config.input_layer_clusters +
"' not found in the model input configuration");
93 }
94
95 if (hit_input_node) {
96 FlavorTagInference::ConstituentsInputConfig cls_config;
97 cls_config.
name =
"tauhits";
104 for (
const auto &in : hit_input_node->
variables) {
106 }
107 addVectorLoader(
config.input_layer_hits, std::make_shared<FlavorTagInference::ConstituentLoaderTauHit>(cls_config,
config.hits_decor_name));
108 }
else if(!
config.input_layer_hits.empty() &&
config.n_max_hits > 0) {
109 ATH_MSG_ERROR(
"Hit input node '" +
config.input_layer_hits +
"' not found in the model input configuration");
110 throw std::runtime_error(
"Hit input node '" +
config.input_layer_hits +
"' not found in the model input configuration");
111 }
112}
#define ATH_MSG_WARNING(x)
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
SaltModelGraphConfig::GraphConfig graph_config
std::string scalarInputName
ScalarCalc_t getScalarCalc(const std::string &name) const