7 #include "lwtnn/parse_json.hh"
16 asg::AsgMessaging(
"TauGNN"),
25 FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
28 for (
const auto& out_node: gnn_output_config) {
35 auto lwtnn_config = m_onnxUtil->getLwtConfig();
42 auto node_is_scalar = [&
config](
const lwt::InputNodeConfig &in_node) {
43 return in_node.name ==
config.input_layer_scalar;
45 auto node_is_track = [&
config](
const lwt::InputNodeConfig &in_node) {
46 return in_node.name ==
config.input_layer_tracks;
48 auto node_is_cluster = [&
config](
const lwt::InputNodeConfig &in_node) {
49 return in_node.name ==
config.input_layer_clusters;
52 auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
53 lwtnn_config.inputs.cend(),
56 auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
57 lwtnn_config.input_sequences.cend(),
60 auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
61 lwtnn_config.input_sequences.cend(),
65 auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
66 auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
67 auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
68 if(!has_scalar_node)
ATH_MSG_WARNING(
"No scalar node with name "<<
config.input_layer_scalar<<
" found!");
69 if(!has_track_node)
ATH_MSG_WARNING(
"No track node with name "<<
config.input_layer_tracks<<
" found!");
70 if(!has_cluster_node)
ATH_MSG_WARNING(
"No cluster node with name "<<
config.input_layer_clusters<<
" found!");
73 if (has_scalar_node) {
74 for (
const auto &in : scalar_node->variables) {
75 std::string
name = in.name;
76 m_scalarCalc_inputs.push_back(
name);
81 for (
const auto &in : track_node->variables) {
82 std::string
name = in.name;
83 m_trackCalc_inputs.push_back(
name);
87 if (has_cluster_node) {
88 for (
const auto &in : cluster_node->variables) {
89 std::string
name = in.name;
90 m_clusterCalc_inputs.push_back(
name);
101 std::map<std::string, float>,
102 std::map<std::string, std::vector<char>>,
103 std::map<std::string, std::vector<float>> >
105 const std::vector<const xAOD::TauTrack *> &tracks,
106 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters)
const {
109 std::map<std::string, Inputs> gnn_input;
114 throw StatusCode::FAILURE;
118 std::vector<float> tau_feats;
122 std::vector<int64_t> tau_feats_dim = {1,
static_cast<int64_t
>(tau_feats.size())};
123 Inputs tau_info (tau_feats, tau_feats_dim);
124 gnn_input.insert({
"tau_vars", tau_info});
127 std::vector<float> trk_feats;
130 trk_feats.resize(num_nodes * num_node_vars);
133 for (
int node_idx=0; node_idx<num_nodes; node_idx++){
134 trk_feats.at(node_idx*num_node_vars + var_idx)
139 std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
140 Inputs trk_info (trk_feats, trk_feats_dim);
141 gnn_input.insert({
"track_vars", trk_info});
144 std::vector<float> cls_feats;
147 cls_feats.resize(num_nodes * num_node_vars);
150 for (
int node_idx=0; node_idx<num_nodes; node_idx++){
151 cls_feats.at(node_idx*num_node_vars + var_idx)
156 std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
157 Inputs cls_info (cls_feats, cls_feats_dim);
158 gnn_input.insert({
"cluster_vars", cls_info});
162 auto [out_f, out_vc, out_vf] =
m_onnxUtil->runInference(gnn_input);
164 return std::make_tuple(out_f, out_vc, out_vf);
168 const std::vector<const xAOD::TauTrack *> &tracks,
169 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters,
170 std::map<std::string, std::map<std::string, double>>& scalarInputs,
171 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs)
const {
172 scalarInputs.clear();
173 vectorInputs.clear();
179 <<
"' returning default");
188 <<
"' returning default");
197 <<
"' returning default");