ATLAS Offline Software
RNNTool.h
Go to the documentation of this file.
1 // This is -*- c++ -*-
2 
3 /*
4  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
5 */
6 
7 #ifndef ATH_RNNTOOL_H
8 #define ATH_RNNTOOL_H
9 
10 /**********************************************************************************
11  * @Package: PhysicsAnpProd
12  * @Class : RNNTool
13  * @Author : Rustem Ospanov
14  *
15  * @Brief : Tool to access RNN from lwtnn library
16  *
17  **********************************************************************************/
18 
19 // Tools
21 
22 // Local
23 #include "IRNNTool.h"
24 
25 // Athena
27 
28 // External
29 #include "lwtnn/LightweightGraph.hh"
30 
31 namespace Prompt
32 {
33  // Forward declarations
34 
35  class VarHolder;
36 
37  // Main body
38 
39  class RNNTool : public AthAlgTool, public IRNNTool
40  {
41  /*
42  RNN tool is based on the lwtnn package.
43  1. It will take the inputs VarHolder objects and convert them into lwt::VectorMap format.
44  2. lwtnn will use lwt::VectorMap format inputs with a given RNN weight json file to predict RNN scores.
45  3. Then this tool will return std::map<std::string, double> object, which contains the name string of the RNN scores and their predictions.
46 
47  */
48  public:
49 
50  RNNTool(const std::string &name,
51  const std::string &type,
52  const IInterface *parent);
53 
54  virtual StatusCode initialize() override;
55 
56  virtual std::map<std::string, double> computeRNNOutput(
57  const std::vector<Prompt::VarHolder> &tracks
58  ) override;
59 
60  virtual std::set<std::string> getOutputLabels() const override { return m_outputLabels; }
61 
62  private:
63 
64  void AddVariable(
65  const std::vector<Prompt::VarHolder> &tracks, unsigned var, std::vector<double> &values
66  );
67 
68  private:
69 
70  std::string m_configPathRNN;
71  std::string m_configRNNVersion;
72  std::string m_configRNNJsonFile;
73 
74  std::string m_inputSequenceName;
75  unsigned m_inputSequenceSize;
76 
77  std::set<std::string> m_outputLabels;
78 
79  std::unique_ptr<lwt::LightweightGraph> m_graph;
80  };
81 }
82 
83 #endif
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
Prompt
Definition: DecoratePromptLeptonImproved.h:45
Prompt::RNNTool::AddVariable
void AddVariable(const std::vector< Prompt::VarHolder > &tracks, unsigned var, std::vector< double > &values)
Definition: RNNTool.cxx:142
Prompt::RNNTool::m_graph
std::unique_ptr< lwt::LightweightGraph > m_graph
Definition: RNNTool.h:86
Prompt::RNNTool::m_inputSequenceSize
unsigned m_inputSequenceSize
Definition: RNNTool.h:82
Prompt::RNNTool::m_inputSequenceName
std::string m_inputSequenceName
Definition: RNNTool.h:81
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:797
Prompt::RNNTool::m_configRNNJsonFile
std::string m_configRNNJsonFile
Definition: RNNTool.h:79
Prompt::RNNTool::initialize
virtual StatusCode initialize() override
Definition: RNNTool.cxx:33
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
test_pyathena.parent
parent
Definition: test_pyathena.py:15
Prompt::RNNTool::RNNTool
RNNTool(const std::string &name, const std::string &type, const IInterface *parent)
Definition: RNNTool.cxx:19
Prompt::RNNTool::getOutputLabels
virtual std::set< std::string > getOutputLabels() const override
Definition: RNNTool.h:67
Prompt::RNNTool::m_configRNNVersion
std::string m_configRNNVersion
Definition: RNNTool.h:78
Prompt::RNNTool::m_outputLabels
std::set< std::string > m_outputLabels
Definition: RNNTool.h:84
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
IRNNTool.h
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
Prompt::RNNTool::m_configPathRNN
std::string m_configPathRNN
Definition: RNNTool.h:77
AthAlgTool
Definition: AthAlgTool.h:26
Prompt::RNNTool::computeRNNOutput
virtual std::map< std::string, double > computeRNNOutput(const std::vector< Prompt::VarHolder > &tracks) override
Definition: RNNTool.cxx:97