13 :
14 FlavorTagInference::SaltModelEDMLoaderBase(salt_model),
15 asg::AsgMessaging("TauGNNDataLoader")
16 {
18 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* scalar_input_node = nullptr;
19 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* track_input_node = nullptr;
20 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* cluster_input_node = nullptr;
22 if (in_node.name ==
config.input_layer_scalar) {
23 scalar_input_node = &in_node;
25 break;
26 }
27 }
28 for (
const auto &in_node :
graph_config.input_sequences) {
29 if (in_node.name ==
config.input_layer_tracks) {
30 track_input_node = &in_node;
32 }
33 if (in_node.name ==
config.input_layer_clusters) {
34 cluster_input_node = &in_node;
36 }
37 }
38
39
40 if (scalar_input_node) {
41 for (
const auto &in : scalar_input_node->
variables) {
43 }
44 } else {
45 ATH_MSG_ERROR(
"Scalar input node 'tau_vars' not found in the model input configuration");
46 throw std::runtime_error("Scalar input node 'tau_vars' not found in the model input configuration");
47 }
48
49 if (track_input_node) {
50 FlavorTagInference::ConstituentsInputConfig trk_config;
51 trk_config.
name =
"tautracks";
58 for (
const auto &in : track_input_node->
variables) {
59 if (!
config.useTRT && (in.name ==
"eProbabilityHT")) {
60 ATH_MSG_WARNING(
"Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
62 continue;
63 }
65 }
66 addVectorLoader(
config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
67 } else {
68 ATH_MSG_ERROR(
"Track input node '" +
config.input_layer_tracks +
"' not found in the model input configuration");
69 throw std::runtime_error(
"Track input node '" +
config.input_layer_tracks +
"' not found in the model input configuration");
70 }
71
72 if (cluster_input_node) {
73 FlavorTagInference::ConstituentsInputConfig cls_config;
74 cls_config.
name =
"tauclusters";
81 for (
const auto &in : cluster_input_node->
variables) {
83 }
84 addVectorLoader(
config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config,
config.max_dr_cluster,
config.doVertexCorrection));
85 } else {
86 ATH_MSG_ERROR(
"Cluster input node '" +
config.input_layer_clusters +
"' not found in the model input configuration");
87 throw std::runtime_error(
"Cluster input node '" +
config.input_layer_clusters +
"' not found in the model input configuration");
88 }
89}
#define ATH_MSG_WARNING(x)
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
SaltModelGraphConfig::GraphConfig graph_config
std::string scalarInputName
ScalarCalc_t getScalarCalc(const std::string &name) const