10 #include "lwtnn/LightweightGraph.hh"
11 #include "lwtnn/Exceptions.hh"
12 #include "lwtnn/parse_json.hh"
18 :
asg::AsgMessaging(
"TauJetRNN"), m_config(
config), m_graph(nullptr) {
21 lwt::GraphConfig lwtnn_config;
24 }
catch (
const std::logic_error &
e) {
30 auto node_is_scalar = [&
config](
const lwt::InputNodeConfig &in_node) {
31 return in_node.name ==
config.input_layer_scalar;
33 auto node_is_track = [&
config](
const lwt::InputNodeConfig &in_node) {
34 return in_node.name ==
config.input_layer_tracks;
36 auto node_is_cluster = [&
config](
const lwt::InputNodeConfig &in_node) {
37 return in_node.name ==
config.input_layer_clusters;
40 auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
41 lwtnn_config.inputs.cend(),
44 auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
45 lwtnn_config.input_sequences.cend(),
48 auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
49 lwtnn_config.input_sequences.cend(),
53 auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
54 auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
55 auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
58 if (has_scalar_node) {
59 for (
const auto &in : scalar_node->variables) {
65 for (
const auto &in : track_node->variables) {
70 if (has_cluster_node) {
71 for (
const auto &in : cluster_node->variables) {
78 m_graph = std::make_unique<lwt::LightweightGraph>(
79 lwtnn_config,
config.output_layer);
80 }
catch (
const lwt::NNConfigurationException &
e) {
92 const std::vector<const xAOD::TauTrack *> &tracks,
93 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters)
const {
100 const auto outputs =
m_graph->compute(scalarInputs, vectorInputs);
106 const std::vector<const xAOD::TauTrack *> &tracks,
107 const std::vector<xAOD::CaloVertexedTopoCluster> &
clusters,
108 std::map<std::string, std::map<std::string, double>>& scalarInputs,
109 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs)
const {
110 scalarInputs.clear();
111 vectorInputs.clear();
117 <<
"' returning default");
126 <<
"' returning default");
135 <<
"' returning default");