16 asg::AsgMessaging(
"TauGNN"),
18 m_config{
config}, m_useTRT(useTRT)
25 FlavorTagInference::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig();
28 for (
const auto& out_node: gnn_output_config) {
35 auto graph_config = m_saltModel->getGraphConfig();
43 return in_node.name ==
config.input_layer_scalar;
46 return in_node.name ==
config.input_layer_tracks;
49 return in_node.name ==
config.input_layer_clusters;
52 auto scalar_node = std::find_if(graph_config.inputs.cbegin(),
53 graph_config.inputs.cend(),
56 auto track_node = std::find_if(graph_config.input_sequences.cbegin(),
57 graph_config.input_sequences.cend(),
60 auto cluster_node = std::find_if(graph_config.input_sequences.cbegin(),
61 graph_config.input_sequences.cend(),
65 auto has_scalar_node = scalar_node != graph_config.inputs.cend();
66 auto has_track_node = track_node != graph_config.input_sequences.cend();
67 auto has_cluster_node = cluster_node != graph_config.input_sequences.cend();
68 if(!has_scalar_node)
ATH_MSG_WARNING(
"No scalar node with name "<<
config.input_layer_scalar<<
" found!");
69 if(!has_track_node)
ATH_MSG_WARNING(
"No track node with name "<<
config.input_layer_tracks<<
" found!");
70 if(!has_cluster_node)
ATH_MSG_WARNING(
"No cluster node with name "<<
config.input_layer_clusters<<
" found!");
73 if (has_scalar_node) {
74 for (
const auto &in : scalar_node->variables) {
75 std::string
name = in.name;
76 m_scalarCalc_inputs.push_back(
name);
81 for (
const auto &in : track_node->variables) {
82 std::string
name = in.name;
83 m_trackCalc_inputs.push_back(
name);
87 if (has_cluster_node) {
88 for (
const auto &in : cluster_node->variables) {
89 std::string
name = in.name;
90 m_clusterCalc_inputs.push_back(
name);
94 m_var_calc = std::make_unique<TauGNNUtils::GNNVarCalc>(m_useTRT);
101 std::map<std::string, float>,
102 std::map<std::string, std::vector<char>>,
103 std::map<std::string, std::vector<float>> >
105 const std::vector<const xAOD::TauTrack *> &tracks,
106 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters)
const {
107 std::map<std::string, Inputs> gnn_input;
112 std::vector<int64_t> tau_feats_dim = {
static_cast<int64_t
>(1),
static_cast<int64_t
>(tau_feats.size())};
113 std::vector<int64_t> trk_feats_dim = {
static_cast<int64_t
>(tracks.size()),
static_cast<int64_t
>(
m_trackCalc_inputs.size())};
116 Inputs tau_info (tau_feats, tau_feats_dim);
117 Inputs trk_info (trk_feats, trk_feats_dim);
118 Inputs cls_info (cls_feats, cls_feats_dim);
120 gnn_input.insert({
"tau_vars", tau_info});
121 gnn_input.insert({
"track_vars", trk_info});
122 gnn_input.insert({
"cluster_vars", cls_info});
126 auto [out_f, out_vc, out_vf] =
m_saltModel->runInference(gnn_input);
128 return std::make_tuple(out_f, out_vc, out_vf);
131 std::tuple<std::vector<float>, std::vector<float>, std::vector<float>>
134 const std::vector<const xAOD::TauTrack *> &tracks,
135 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters
138 std::vector<float> tau_feats;
139 std::vector<std::vector<float>> track_feats_2d, cluster_feats_2d;
150 std::vector<float> track_feats =
flatten(track_feats_2d);
151 std::vector<float> cluster_feats =
flatten(cluster_feats_2d);
152 return std::make_tuple(tau_feats, track_feats, cluster_feats);