20 std::ifstream input_file(filename);
21 lwt::GraphConfig lwtnn_config;
23 lwtnn_config = lwt::parse_json_graph(input_file);
24 }
catch (
const std::logic_error &e) {
30 auto node_is_scalar = [&
config](
const lwt::InputNodeConfig &in_node) {
31 return in_node.name ==
config.input_layer_scalar;
33 auto node_is_track = [&
config](
const lwt::InputNodeConfig &in_node) {
34 return in_node.name ==
config.input_layer_tracks;
36 auto node_is_cluster = [&
config](
const lwt::InputNodeConfig &in_node) {
37 return in_node.name ==
config.input_layer_clusters;
40 auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
41 lwtnn_config.inputs.cend(),
44 auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
45 lwtnn_config.input_sequences.cend(),
48 auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
49 lwtnn_config.input_sequences.cend(),
53 auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
54 auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
55 auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
58 if (has_scalar_node) {
59 for (
const auto &in : scalar_node->variables) {
65 for (
const auto &in : track_node->variables) {
70 if (has_cluster_node) {
71 for (
const auto &in : cluster_node->variables) {
78 m_graph = std::make_unique<lwt::LightweightGraph>(
79 lwtnn_config,
config.output_layer);
80 }
catch (
const lwt::NNConfigurationException &e) {
92 const std::vector<const xAOD::TauTrack *> &tracks,
93 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters)
const {
100 const auto outputs =
m_graph->compute(scalarInputs, vectorInputs);
102 return outputs.at(
m_config.output_node);
106 const std::vector<const xAOD::TauTrack *> &tracks,
107 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
108 std::map<std::string, std::map<std::string, double>>& scalarInputs,
109 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs)
const {
110 scalarInputs.clear();
111 vectorInputs.clear();
115 scalarInputs[
m_config.input_layer_scalar][varname])) {
117 <<
"' returning default");
123 if (!
m_var_calc->compute(varname, tau, tracks,
124 vectorInputs[
m_config.input_layer_tracks][varname])) {
126 <<
"' returning default");
132 if (!
m_var_calc->compute(varname, tau, clusters,
133 vectorInputs[
m_config.input_layer_clusters][varname])) {
135 <<
"' returning default");
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const