ATLAS Offline Software
Loading...
Searching...
No Matches
Prompt::RNNTool Class Reference

#include <RNNTool.h>

Inheritance diagram for Prompt::RNNTool:
Collaboration diagram for Prompt::RNNTool:

Public Member Functions

 RNNTool (const std::string &name, const std::string &type, const IInterface *parent)
virtual StatusCode initialize () override
virtual std::map< std::string, double > computeRNNOutput (const std::vector< Prompt::VarHolder > &tracks) override
virtual std::set< std::string > getOutputLabels () const override
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const

Static Public Member Functions

static const InterfaceID & interfaceID ()

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

void AddVariable (const std::vector< Prompt::VarHolder > &tracks, unsigned var, std::vector< double > &values)
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>

Private Attributes

std::string m_configPathRNN
std::string m_configRNNVersion
std::string m_configRNNJsonFile
std::string m_inputSequenceName
unsigned m_inputSequenceSize
std::set< std::string > m_outputLabels
std::unique_ptr< lwt::LightweightGraph > m_graph
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared

Detailed Description

Definition at line 39 of file RNNTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ RNNTool()

Prompt::RNNTool::RNNTool ( const std::string & name,
const std::string & type,
const IInterface * parent )

Definition at line 19 of file RNNTool.cxx.

19 :
20 AthAlgTool(name, type, parent)
21{
22 declareInterface<Prompt::IRNNTool>(this);
23
24 declareProperty("configPathRNN", m_configPathRNN, "Path of the local RNN json file you want o study/test, it will override the PathResolverFindCalibFile file");
25 declareProperty("configRNNVersion", m_configRNNVersion, "RNN version in cvmfs");
26 declareProperty("configRNNJsonFile", m_configRNNJsonFile, "Name of the RNN json file in cvmfs");
27
28 declareProperty("inputSequenceName", m_inputSequenceName = "Trk_inputs", "Prefix of the variables used in the RNN json file");
29 declareProperty("inputSequenceSize", m_inputSequenceSize = 5, "Number of tracks used in the RNN");
30}
AthAlgTool()
Default constructor:
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)
unsigned m_inputSequenceSize
Definition RNNTool.h:75
std::string m_inputSequenceName
Definition RNNTool.h:74
std::string m_configPathRNN
Definition RNNTool.h:70
std::string m_configRNNJsonFile
Definition RNNTool.h:72
std::string m_configRNNVersion
Definition RNNTool.h:71

Member Function Documentation

◆ AddVariable()

void Prompt::RNNTool::AddVariable ( const std::vector< Prompt::VarHolder > & tracks,
unsigned var,
std::vector< double > & values )
private

Definition at line 142 of file RNNTool.cxx.

145{
146 //
147 // Read values
148 //
149 const unsigned nvar = std::min<unsigned>(tracks.size(), m_inputSequenceSize);
150
151 for(unsigned i = 0; i < nvar; ++i) {
152 double value = 0.0;
153
154 if(i < tracks.size()) {
155 const Prompt::VarHolder &track = tracks.at(i);
156
157 if(!track.getVar(var, value)) {
158 ATH_MSG_WARNING("RNNTool::AddVariable - missing variable");
159 }
160 }
161
162 values.push_back(value);
163 }
164}
#define ATH_MSG_WARNING(x)

◆ computeRNNOutput()

std::map< std::string, double > Prompt::RNNTool::computeRNNOutput ( const std::vector< Prompt::VarHolder > & tracks)
override, virtual

Implements Prompt::IRNNTool.

Definition at line 97 of file RNNTool.cxx.

98{
99 lwt::ValueMap values;
100
101 lwt::LightweightGraph::NodeMap nodes;
102 lwt::LightweightGraph::SeqNodeMap seqs;
103
104 lwt::VectorMap &vmap = seqs[m_inputSequenceName];
105
106 AddVariable(tracks, Def::NumberOfPIXHits, vmap["m_cone_tracks_numberOfPixelHits"]);
107 AddVariable(tracks, Def::NumberOfSCTHits, vmap["m_cone_tracks_numberOfSCTHits"]);
108 AddVariable(tracks, Def::Z0Sin, vmap["m_cone_tracks_Z0Sin"]);
109 AddVariable(tracks, Def::D0Sig, vmap["m_cone_tracks_D0Sig"]);
110 AddVariable(tracks, Def::TrackJetDR, vmap["m_cone_tracks_DRTrackJet"]);
111 AddVariable(tracks, Def::TrackPtOverTrackJetPt, vmap["m_cone_tracks_PtRelOverTrackJetPt"]);
112
113 if(vmap.size() != 6) {
114 ATH_MSG_WARNING("RNNTool::computeRNNOutput - incomplete variables: return empty result");
115 return lwt::ValueMap();
116 }
117
118 unsigned nwid = 0;
119
120 for(const lwt::VectorMap::value_type &v: vmap) {
121 nwid = std::max<unsigned>(v.first.size(), nwid);
122 }
123
124 for(const lwt::VectorMap::value_type &v: vmap) {
125 ATH_MSG_DEBUG(std::setw(nwid+1) << std::left << v.first << " ");
126
127 for(const double d: v.second) {
128 ATH_MSG_DEBUG(d << ", ");
129 }
130 }
131
132 lwt::ValueMap result = m_graph->compute(nodes, seqs);
133
134 for(const lwt::ValueMap::value_type &v: result) {
135 ATH_MSG_DEBUG(v.first << " score=" << std::setprecision(10) << v.second);
136 }
137
138 return result;
139}
#define ATH_MSG_DEBUG(x)
void AddVariable(const std::vector< Prompt::VarHolder > &tracks, unsigned var, std::vector< double > &values)
Definition RNNTool.cxx:142
std::unique_ptr< lwt::LightweightGraph > m_graph
Definition RNNTool.h:79
@ NumberOfPIXHits
Definition VarHolder.h:48
@ NumberOfSCTHits
Definition VarHolder.h:49
@ TrackPtOverTrackJetPt
Definition VarHolder.h:55

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158 {
159 return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160 hndl.value(),
161 hndl.documentation());
162
163 }

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145 {
146 typedef typename SG::HandleClassifier<T>::type htype;
147 return AthCommonDataStore<PBASE>::declareGaudiProperty(t, htype());
148 }
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName when it is not explicitly given.

◆ getOutputLabels()

virtual std::set< std::string > Prompt::RNNTool::getOutputLabels ( ) const
inline, override, virtual

Implements Prompt::IRNNTool.

Definition at line 60 of file RNNTool.h.

60{ return m_outputLabels; }
std::set< std::string > m_outputLabels
Definition RNNTool.h:77

◆ initialize()

StatusCode Prompt::RNNTool::initialize ( )
override, virtual

Implements Prompt::IRNNTool.

Definition at line 33 of file RNNTool.cxx.

34{
35 //
36 // Get path to xml training file
37 //
38 std::string fullPathToFile;
39
40 if(!m_configPathRNN.empty()) {
41 ATH_MSG_INFO("Override PathResolver to this path: " << m_configPathRNN);
42 fullPathToFile = m_configPathRNN;
43 }
44 else {
45 fullPathToFile = PathResolverFindCalibFile("JetTagNonPromptLepton/"
46 + m_configRNNVersion + "/"
47 + m_configRNNJsonFile);
48 }
49
50 ATH_MSG_INFO("initialize RNNTool - ConfigPathRNN: \"" << fullPathToFile);
51
52 //
53 // Configure RNN
54 //
55 std::ifstream input_stream(fullPathToFile);
56
57 lwt::GraphConfig graph_config = lwt::parse_json_graph(input_stream);
58
59 m_graph = std::make_unique<lwt::LightweightGraph>(graph_config);
60
61 for(const auto &o: graph_config.outputs) {
62 ATH_MSG_DEBUG(" output name: " << o.first << ", node_index=" << o.second.node_index);
63
64 for(const auto &l: o.second.labels) {
65 ATH_MSG_DEBUG(" label=" << l);
66
67 if(!m_outputLabels.insert(l).second) {
68 ATH_MSG_WARNING("Duplicate output label=\"" << l << "\"");
69 }
70 }
71 }
72
73 ATH_MSG_DEBUG("Number of input sequences: " << graph_config.input_sequences.size());
74
75 for(const auto &n: graph_config.input_sequences) {
76 ATH_MSG_DEBUG(" sequence name=" << n.name);
77
78 for(const lwt::Input &v: n.variables) {
79 ATH_MSG_DEBUG(" variable=" << v.name);
80 }
81 }
82
83 ATH_MSG_DEBUG("Number of inputs: " << graph_config.inputs.size());
84
85 for(const auto &n: graph_config.inputs) {
86 ATH_MSG_DEBUG(" input name=" << n.name);
87
88 for(const lwt::Input &v: n.variables) {
89 ATH_MSG_DEBUG(" variable=" << v.name);
90 }
91 }
92
93 return StatusCode::SUCCESS;
94}
#define ATH_MSG_INFO(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ interfaceID()

const InterfaceID & Prompt::IRNNTool::interfaceID ( )
inline, static, inherited

Definition at line 44 of file IRNNTool.h.

44{ return IID_IRNNTool; }
static const InterfaceID IID_IRNNTool("Prompt::IRNNTool", 1, 0)

◆ msg()

MsgStream & AthCommonMsg< AlgTool >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24 {
25 return this->msgStream();
26 }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30 {
31 return this->msgLevel(lvl);
32 }

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T & h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381 {
382 h.renounce();
383 PBASE::renounce (h);
384 }
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline, protected, inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364 {
365 handlesArray.renounce();
366 }

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override, virtual, inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in asg::AsgMetadataTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and DerivationFramework::CfAthAlgTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override, virtual, inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inline, inherited

Definition at line 308 of file AthCommonDataStore.h.

308 {
309 // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310 // << " size: " << m_vhka.size() << endmsg;
311 for (auto &a : m_vhka) {
312 std::vector<SG::VarHandleKey*> keys = a->keys();
313 for (auto k : keys) {
314 k->setOwner(this);
315 }
316 }
317 }
std::vector< SG::VarHandleKeyArray * > m_vhka

Member Data Documentation

◆ m_configPathRNN

std::string Prompt::RNNTool::m_configPathRNN
private

Definition at line 70 of file RNNTool.h.

◆ m_configRNNJsonFile

std::string Prompt::RNNTool::m_configRNNJsonFile
private

Definition at line 72 of file RNNTool.h.

◆ m_configRNNVersion

std::string Prompt::RNNTool::m_configRNNVersion
private

Definition at line 71 of file RNNTool.h.

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private, inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private, inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_graph

std::unique_ptr<lwt::LightweightGraph> Prompt::RNNTool::m_graph
private

Definition at line 79 of file RNNTool.h.

◆ m_inputSequenceName

std::string Prompt::RNNTool::m_inputSequenceName
private

Definition at line 74 of file RNNTool.h.

◆ m_inputSequenceSize

unsigned Prompt::RNNTool::m_inputSequenceSize
private

Definition at line 75 of file RNNTool.h.

◆ m_outputLabels

std::set<std::string> Prompt::RNNTool::m_outputLabels
private

Definition at line 77 of file RNNTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private, inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private, inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: