ATLAS Offline Software
Loading...
Searching...
No Matches
RNNTool.cxx
Go to the documentation of this file.
/*
  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
*/

5// Local
8
9// External
10#include "lwtnn/parse_json.hh"
11#include "lwtnn/NanReplacer.hh"
12
13// C/C++
14#include <fstream>
15
16using namespace std;
17
18//=============================================================================
19Prompt::RNNTool::RNNTool(const std::string &name, const std::string &type, const IInterface *parent):
20 AthAlgTool(name, type, parent)
21{
22 declareInterface<Prompt::IRNNTool>(this);
23
24 declareProperty("configPathRNN", m_configPathRNN, "Path of the local RNN json file you want o study/test, it will override the PathResolverFindCalibFile file");
25 declareProperty("configRNNVersion", m_configRNNVersion, "RNN version in cvmfs");
26 declareProperty("configRNNJsonFile", m_configRNNJsonFile, "Name of the RNN json file in cvmfs");
27
28 declareProperty("inputSequenceName", m_inputSequenceName = "Trk_inputs", "Prefix of the variables used in the RNN json file");
29 declareProperty("inputSequenceSize", m_inputSequenceSize = 5, "Number of tracks used in the RNN");
30}
31
32//=============================================================================
34{
35 //
36 // Get path to xml training file
37 //
38 std::string fullPathToFile;
39
40 if(!m_configPathRNN.empty()) {
41 ATH_MSG_INFO("Override PathResolver to this path: " << m_configPathRNN);
42 fullPathToFile = m_configPathRNN;
43 }
44 else {
45 fullPathToFile = PathResolverFindCalibFile("JetTagNonPromptLepton/"
46 + m_configRNNVersion + "/"
48 }
49
50 ATH_MSG_INFO("initialize RNNTool - ConfigPathRNN: \"" << fullPathToFile);
51
52 //
53 // Configure RNN
54 //
55 std::ifstream input_stream(fullPathToFile);
56
57 lwt::GraphConfig graph_config = lwt::parse_json_graph(input_stream);
58
59 m_graph = std::make_unique<lwt::LightweightGraph>(graph_config);
60
61 for(const auto &o: graph_config.outputs) {
62 ATH_MSG_DEBUG(" output name: " << o.first << ", node_index=" << o.second.node_index);
63
64 for(const auto &l: o.second.labels) {
65 ATH_MSG_DEBUG(" label=" << l);
66
67 if(!m_outputLabels.insert(l).second) {
68 ATH_MSG_WARNING("Duplicate output label=\"" << l << "\"");
69 }
70 }
71 }
72
73 ATH_MSG_DEBUG("Number of input sequences: " << graph_config.input_sequences.size());
74
75 for(const auto &n: graph_config.input_sequences) {
76 ATH_MSG_DEBUG(" sequence name=" << n.name);
77
78 for(const lwt::Input &v: n.variables) {
79 ATH_MSG_DEBUG(" variable=" << v.name);
80 }
81 }
82
83 ATH_MSG_DEBUG("Number of inputs: " << graph_config.inputs.size());
84
85 for(const auto &n: graph_config.inputs) {
86 ATH_MSG_DEBUG(" input name=" << n.name);
87
88 for(const lwt::Input &v: n.variables) {
89 ATH_MSG_DEBUG(" variable=" << v.name);
90 }
91 }
92
93 return StatusCode::SUCCESS;
94}
95
96//=============================================================================
97std::map<std::string, double> Prompt::RNNTool::computeRNNOutput(const std::vector<Prompt::VarHolder> &tracks)
98{
99 lwt::ValueMap values;
100
101 lwt::LightweightGraph::NodeMap nodes;
102 lwt::LightweightGraph::SeqNodeMap seqs;
103
104 lwt::VectorMap &vmap = seqs[m_inputSequenceName];
105
106 AddVariable(tracks, Def::NumberOfPIXHits, vmap["m_cone_tracks_numberOfPixelHits"]);
107 AddVariable(tracks, Def::NumberOfSCTHits, vmap["m_cone_tracks_numberOfSCTHits"]);
108 AddVariable(tracks, Def::Z0Sin, vmap["m_cone_tracks_Z0Sin"]);
109 AddVariable(tracks, Def::D0Sig, vmap["m_cone_tracks_D0Sig"]);
110 AddVariable(tracks, Def::TrackJetDR, vmap["m_cone_tracks_DRTrackJet"]);
111 AddVariable(tracks, Def::TrackPtOverTrackJetPt, vmap["m_cone_tracks_PtRelOverTrackJetPt"]);
112
113 if(vmap.size() != 6) {
114 ATH_MSG_WARNING("RNNTool::computeRNNOutput - incomplete variables: return empty result");
115 return lwt::ValueMap();
116 }
117
118 unsigned nwid = 0;
119
120 for(const lwt::VectorMap::value_type &v: vmap) {
121 nwid = std::max<unsigned>(v.first.size(), nwid);
122 }
123
124 for(const lwt::VectorMap::value_type &v: vmap) {
125 ATH_MSG_DEBUG(std::setw(nwid+1) << std::left << v.first << " ");
126
127 for(const double d: v.second) {
128 ATH_MSG_DEBUG(d << ", ");
129 }
130 }
131
132 lwt::ValueMap result = m_graph->compute(nodes, seqs);
133
134 for(const lwt::ValueMap::value_type &v: result) {
135 ATH_MSG_DEBUG(v.first << " score=" << std::setprecision(10) << v.second);
136 }
137
138 return result;
139}
140
141//=============================================================================
142void Prompt::RNNTool::AddVariable(const std::vector<Prompt::VarHolder> &tracks,
143 unsigned var,
144 std::vector<double> &values)
145{
146 //
147 // Read values
148 //
149 const unsigned nvar = std::min<unsigned>(tracks.size(), m_inputSequenceSize);
150
151 for(unsigned i = 0; i < nvar; ++i) {
152 double value = 0.0;
153
154 if(i < tracks.size()) {
155 const Prompt::VarHolder &track = tracks.at(i);
156
157 if(!track.getVar(var, value)) {
158 ATH_MSG_WARNING("RNNTool::AddVariable - missing variable");
159 }
160 }
161
162 values.push_back(value);
163 }
164}
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)
RNNTool(const std::string &name, const std::string &type, const IInterface *parent)
Definition RNNTool.cxx:19
virtual std::map< std::string, double > computeRNNOutput(const std::vector< Prompt::VarHolder > &tracks) override
Definition RNNTool.cxx:97
unsigned m_inputSequenceSize
Definition RNNTool.h:75
virtual StatusCode initialize() override
Definition RNNTool.cxx:33
std::string m_inputSequenceName
Definition RNNTool.h:74
void AddVariable(const std::vector< Prompt::VarHolder > &tracks, unsigned var, std::vector< double > &values)
Definition RNNTool.cxx:142
std::string m_configPathRNN
Definition RNNTool.h:70
std::set< std::string > m_outputLabels
Definition RNNTool.h:77
std::string m_configRNNJsonFile
Definition RNNTool.h:72
std::string m_configRNNVersion
Definition RNNTool.h:71
std::unique_ptr< lwt::LightweightGraph > m_graph
Definition RNNTool.h:79
@ NumberOfPIXHits
Definition VarHolder.h:48
@ NumberOfSCTHits
Definition VarHolder.h:49
@ TrackPtOverTrackJetPt
Definition VarHolder.h:55
STL namespace.