7 #include "lwtnn/parse_json.hh"
16 asg::AsgMessaging(
"TauGNN"),
23 m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile);
36 auto lwtnn_config =
m_onnxUtil->getLwtConfig();
43 auto node_is_scalar = [&
config](
const lwt::InputNodeConfig &in_node) {
44 return in_node.name ==
config.input_layer_scalar;
46 auto node_is_track = [&
config](
const lwt::InputNodeConfig &in_node) {
47 return in_node.name ==
config.input_layer_tracks;
49 auto node_is_cluster = [&
config](
const lwt::InputNodeConfig &in_node) {
50 return in_node.name ==
config.input_layer_clusters;
53 auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
54 lwtnn_config.inputs.cend(),
57 auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
58 lwtnn_config.input_sequences.cend(),
61 auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
62 lwtnn_config.input_sequences.cend(),
66 auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
67 auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
68 auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
69 if(!has_scalar_node)
ATH_MSG_WARNING(
"No scalar node with name "<<
config.input_layer_scalar<<
" found!");
70 if(!has_track_node)
ATH_MSG_WARNING(
"No track node with name "<<
config.input_layer_tracks<<
" found!");
71 if(!has_cluster_node)
ATH_MSG_WARNING(
"No cluster node with name "<<
config.input_layer_clusters<<
" found!");
74 if (has_scalar_node) {
75 for (
const auto &in : scalar_node->variables) {
76 std::string
name = in.name;
82 for (
const auto &in : track_node->variables) {
83 std::string
name = in.name;
88 if (has_cluster_node) {
89 for (
const auto &in : cluster_node->variables) {
90 std::string
name = in.name;
102 std::map<std::string, float>,
103 std::map<std::string, std::vector<char>>,
104 std::map<std::string, std::vector<float>> >
106 const std::vector<const xAOD::TauTrack *> &tracks,
107 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters)
const {
110 std::map<std::string, Inputs> gnn_input;
115 throw StatusCode::FAILURE;
119 std::vector<float> tau_feats;
123 std::vector<int64_t> tau_feats_dim = {1,
static_cast<int64_t
>(tau_feats.size())};
124 Inputs tau_info (tau_feats, tau_feats_dim);
125 gnn_input.insert({
"tau_vars", tau_info});
128 std::vector<float> trk_feats;
131 trk_feats.resize(num_nodes * num_node_vars);
134 for (
int node_idx=0; node_idx<num_nodes; node_idx++){
135 trk_feats.at(node_idx*num_node_vars + var_idx)
140 std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
141 Inputs trk_info (trk_feats, trk_feats_dim);
142 gnn_input.insert({
"track_vars", trk_info});
145 std::vector<float> cls_feats;
148 cls_feats.resize(num_nodes * num_node_vars);
151 for (
int node_idx=0; node_idx<num_nodes; node_idx++){
152 cls_feats.at(node_idx*num_node_vars + var_idx)
157 std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
158 Inputs cls_info (cls_feats, cls_feats_dim);
159 gnn_input.insert({
"cluster_vars", cls_info});
163 auto [out_f, out_vc, out_vf] =
m_onnxUtil->runInference(gnn_input);
165 return std::make_tuple(out_f, out_vc, out_vf);
169 const std::vector<const xAOD::TauTrack *> &tracks,
170 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters,
171 std::map<std::string, std::map<std::string, double>>& scalarInputs,
172 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs)
const {
173 scalarInputs.clear();
174 vectorInputs.clear();
180 <<
"' returning default");
189 <<
"' returning default");
198 <<
"' returning default");