ATLAS Offline Software
RNNTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 // Local
8 
9 // External
10 #include "lwtnn/parse_json.hh"
11 #include "lwtnn/NanReplacer.hh"
12 
13 // C/C++
14 #include <fstream>
15 
16 using namespace std;
17 
18 //=============================================================================
19 Prompt::RNNTool::RNNTool(const std::string &name, const std::string &type, const IInterface *parent):
21 {
22  declareInterface<Prompt::IRNNTool>(this);
23 
24  declareProperty("configPathRNN", m_configPathRNN, "Path of the local RNN json file you want o study/test, it will override the PathResolverFindCalibFile file");
25  declareProperty("configRNNVersion", m_configRNNVersion, "RNN version in cvmfs");
26  declareProperty("configRNNJsonFile", m_configRNNJsonFile, "Name of the RNN json file in cvmfs");
27 
28  declareProperty("inputSequenceName", m_inputSequenceName = "Trk_inputs", "Prefix of the variables used in the RNN json file");
29  declareProperty("inputSequenceSize", m_inputSequenceSize = 5, "Number of tracks used in the RNN");
30 }
31 
32 //=============================================================================
34 {
35  //
36  // Get path to xml training file
37  //
38  std::string fullPathToFile;
39 
40  if(!m_configPathRNN.empty()) {
41  ATH_MSG_INFO("Override PathResolver to this path: " << m_configPathRNN);
42  fullPathToFile = m_configPathRNN;
43  }
44  else {
45  fullPathToFile = PathResolverFindCalibFile("JetTagNonPromptLepton/"
46  + m_configRNNVersion + "/"
47  + m_configRNNJsonFile);
48  }
49 
50  ATH_MSG_INFO("initialize RNNTool - ConfigPathRNN: \"" << fullPathToFile);
51 
52  //
53  // Configure RNN
54  //
55  std::ifstream input_stream(fullPathToFile);
56 
57  lwt::GraphConfig graph_config = lwt::parse_json_graph(input_stream);
58 
59  m_graph = std::make_unique<lwt::LightweightGraph>(graph_config);
60 
61  for(const auto &o: graph_config.outputs) {
62  ATH_MSG_DEBUG(" output name: " << o.first << ", node_index=" << o.second.node_index);
63 
64  for(const auto &l: o.second.labels) {
65  ATH_MSG_DEBUG(" label=" << l);
66 
67  if(!m_outputLabels.insert(l).second) {
68  ATH_MSG_WARNING("Duplicate output label=\"" << l << "\"");
69  }
70  }
71  }
72 
73  ATH_MSG_DEBUG("Number of input sequences: " << graph_config.input_sequences.size());
74 
75  for(const auto &n: graph_config.input_sequences) {
76  ATH_MSG_DEBUG(" sequence name=" << n.name);
77 
78  for(const lwt::Input &v: n.variables) {
79  ATH_MSG_DEBUG(" variable=" << v.name);
80  }
81  }
82 
83  ATH_MSG_DEBUG("Number of inputs: " << graph_config.inputs.size());
84 
85  for(const auto &n: graph_config.inputs) {
86  ATH_MSG_DEBUG(" input name=" << n.name);
87 
88  for(const lwt::Input &v: n.variables) {
89  ATH_MSG_DEBUG(" variable=" << v.name);
90  }
91  }
92 
93  return StatusCode::SUCCESS;
94 }
95 
96 //=============================================================================
//
// Build the RNN input sequence from the leading tracks and evaluate the lwtnn graph.
//
// Six per-track variables are read from each Prompt::VarHolder via AddVariable.
// They are stored in the sequence named m_inputSequenceName, under the
// "m_cone_tracks_*" keys used below.
// NOTE(review): these keys presumably must match the variable names of that input
// sequence in the json file loaded by initialize() — confirm.
//
// Returns the map produced by m_graph->compute(nodes, seqs), logged below as
// label -> score. Returns an empty map when the variable-count check below fails.
//
97 std::map<std::string, double> Prompt::RNNTool::computeRNNOutput(const std::vector<Prompt::VarHolder> &tracks)
98 {
// NOTE(review): the declarations of `nodes` and `seqs` (original lines 99-102) are
// missing from this listing. From their use below, `seqs` maps a sequence name to
// an lwt::VectorMap, and both are passed to m_graph->compute.
100
103
104  lwt::VectorMap &vmap = seqs[m_inputSequenceName];
105
// Each vmap["..."] access creates the entry if needed; AddVariable appends one
// value per track (at most m_inputSequenceSize) to that entry.
106  AddVariable(tracks, Def::NumberOfPIXHits, vmap["m_cone_tracks_numberOfPixelHits"]);
107  AddVariable(tracks, Def::NumberOfSCTHits, vmap["m_cone_tracks_numberOfSCTHits"]);
108  AddVariable(tracks, Def::Z0Sin, vmap["m_cone_tracks_Z0Sin"]);
109  AddVariable(tracks, Def::D0Sig, vmap["m_cone_tracks_D0Sig"]);
110  AddVariable(tracks, Def::TrackJetDR, vmap["m_cone_tracks_DRTrackJet"]);
111  AddVariable(tracks, Def::TrackPtOverTrackJetPt, vmap["m_cone_tracks_PtRelOverTrackJetPt"]);
112
// NOTE(review): std::map::operator[] above inserts each of the six keys. If vmap
// starts empty, it therefore always holds exactly 6 entries here and this check
// cannot fire. It does not detect tracks missing a variable; AddVariable only
// emits a warning in that case.
113  if(vmap.size() != 6) {
114  ATH_MSG_WARNING("RNNTool::computeRNNOutput - incomplete variables: return empty result");
115  return lwt::ValueMap();
116  }
117
// Width of the longest variable name, used to align the debug dump below.
118  unsigned nwid = 0;
119
120  for(const lwt::VectorMap::value_type &v: vmap) {
121  nwid = std::max<unsigned>(v.first.size(), nwid);
122  }
123
// Debug dump of the input sequence: variable name followed by its per-track values.
// NOTE(review): each ATH_MSG_DEBUG call is presumably a separate message, so every
// value ends up on its own line instead of one row per variable.
// std::setw/std::setprecision come from <iomanip>, which is not among the visible
// includes — presumably pulled in transitively; confirm.
124  for(const lwt::VectorMap::value_type &v: vmap) {
125  ATH_MSG_DEBUG(std::setw(nwid+1) << std::left << v.first << " ");
126
127  for(const double d: v.second) {
128  ATH_MSG_DEBUG(d << ", ");
129  }
130  }
131
// Evaluate the graph on the scalar-input map `nodes` and the sequence map `seqs`.
132  lwt::ValueMap result = m_graph->compute(nodes, seqs);
133
134  for(const lwt::ValueMap::value_type &v: result) {
135  ATH_MSG_DEBUG(v.first << " score=" << std::setprecision(10) << v.second);
136  }
137
138  return result;
139 }
140 
141 //=============================================================================
142 void Prompt::RNNTool::AddVariable(const std::vector<Prompt::VarHolder> &tracks,
143  unsigned var,
144  std::vector<double> &values)
145 {
146  //
147  // Read values
148  //
149  const unsigned nvar = std::min<unsigned>(tracks.size(), m_inputSequenceSize);
150 
151  for(unsigned i = 0; i < nvar; ++i) {
152  double value = 0.0;
153 
154  if(i < tracks.size()) {
155  const Prompt::VarHolder &track = tracks.at(i);
156 
157  if(!track.getVar(var, value)) {
158  ATH_MSG_WARNING("RNNTool::AddVariable - missing variable");
159  }
160  }
161 
162  values.push_back(value);
163  }
164 }
RNNTool.h
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
Prompt::Def::NumberOfSCTHits
@ NumberOfSCTHits
Definition: VarHolder.h:79
get_generator_info.result
result
Definition: get_generator_info.py:21
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
hist_file_dump.d
d
Definition: hist_file_dump.py:137
Prompt::RNNTool::AddVariable
void AddVariable(const std::vector< Prompt::VarHolder > &tracks, unsigned var, std::vector< double > &values)
Definition: RNNTool.cxx:142
Prompt::VarHolder
Definition: VarHolder.h:112
athena.value
value
Definition: athena.py:124
UploadAMITag.l
list l
Definition: UploadAMITag.larcaf.py:158
Prompt::RNNTool::m_inputSequenceSize
unsigned m_inputSequenceSize
Definition: RNNTool.h:82
lwtDev::NodeMap
LightweightGraph::NodeMap NodeMap
Definition: LightweightGraph.cxx:67
Prompt::RNNTool::m_inputSequenceName
std::string m_inputSequenceName
Definition: RNNTool.h:81
Prompt::Def::Z0Sin
@ Z0Sin
Definition: VarHolder.h:86
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
Prompt::Def::NumberOfPIXHits
@ NumberOfPIXHits
Definition: VarHolder.h:78
Prompt::RNNTool::m_configRNNJsonFile
std::string m_configRNNJsonFile
Definition: RNNTool.h:79
Prompt::Def::D0Sig
@ D0Sig
Definition: VarHolder.h:87
lumiFormat.i
int i
Definition: lumiFormat.py:85
beamspotman.n
n
Definition: beamspotman.py:731
Prompt::RNNTool::initialize
virtual StatusCode initialize() override
Definition: RNNTool.cxx:33
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
test_pyathena.parent
parent
Definition: test_pyathena.py:15
Prompt::RNNTool::RNNTool
RNNTool(const std::string &name, const std::string &type, const IInterface *parent)
Definition: RNNTool.cxx:19
Prompt::RNNTool::m_configRNNVersion
std::string m_configRNNVersion
Definition: RNNTool.h:78
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
VarHolder.h
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauDecayModeNNClassifier.cxx:23
python.PyAthena.v
v
Definition: PyAthena.py:154
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
Prompt::Def::TrackPtOverTrackJetPt
@ TrackPtOverTrackJetPt
Definition: VarHolder.h:85
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
Prompt::Def::TrackJetDR
@ TrackJetDR
Definition: VarHolder.h:84
Prompt::RNNTool::m_configPathRNN
std::string m_configPathRNN
Definition: RNNTool.h:77
xAOD::track
@ track
Definition: TrackingPrimitives.h:512
tauRecTools::SeqNodeMap
std::map< std::string, VectorMap > SeqNodeMap
Definition: TauTrackRNNClassifier.h:44
AthAlgTool
Definition: AthAlgTool.h:26
Prompt::RNNTool::computeRNNOutput
virtual std::map< std::string, double > computeRNNOutput(const std::vector< Prompt::VarHolder > &tracks) override
Definition: RNNTool.cxx:97
ValueMap
std::map< std::string, double > ValueMap
Definition: TauDecayModeNNClassifier.cxx:22
Input
NswErrorCalibData::Input Input
Definition: NswErrorCalibData.cxx:6