ATLAS Offline Software
TRTPIDNN.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 #ifndef INDETTRTPIDNN_H
5 #define INDETTRTPIDNN_H
6 
8 // TRTPIDNN.h, (c) ATLAS Detector software
10 
11 /****************************************************************************************\
12 
13  Class to wrap the lwtnn instance of the TRT PID NN. It is instantiated in PIDNNCondAlg.
14 
15  Author: Christian Grefe (christian.grefe@cern.ch)
16 
17 \****************************************************************************************/
18 #include "GaudiKernel/StatusCode.h"
19 #include "AthenaKernel/CLASS_DEF.h"
20 #include "AthenaKernel/CondCont.h"
21 #include "lwtnn/LightweightGraph.hh"
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <vector>
26 
27 namespace InDet {
// Wrapper around the lwtnn instance of the TRT PID neural network.
// Holds the network itself (lwt::LightweightGraph), its lwt::GraphConfig,
// template maps describing the structure of the scalar and vector inputs,
// and the names of the default output node and output label.
// According to the file header, it is instantiated in PIDNNCondAlg.
28  class TRTPIDNN {
29  public:
 // Default construction leaves all members default-initialised (m_nn is a null
 // unique_ptr, maps and strings are empty).
 // NOTE(review): presumably configure() must be called before evaluate() — confirm in TRTPIDNN.cxx.
30  TRTPIDNN()=default;
31  virtual ~TRTPIDNN()=default;
32 
 // Name of the output node used by the two-argument evaluate() overload.
 // Returns a reference to the member m_outputNode.
33  const std::string& getDefaultOutputNode() const {
34  return m_outputNode;
35  }
36 
 // Name of the output label used by the two-argument evaluate() overload.
 // Returns a reference to the member m_outputLabel.
37  const std::string& getDefaultOutputLabel() const {
38  return m_outputLabel;
39  }
40 
41  // Read-only access to the template describing the structure of the scalar inputs to the NN.
 // NOTE(review): presumably outer key = input node name, inner key = variable name — confirm against lwtnn.
42  const std::map<std::string, std::map<std::string, double>>& getScalarInputs() const {
43  return m_scalarInputs;
44  }
45 
46  // Read-only access to the template describing the structure of the vector inputs to the NN.
 // Same nesting as the scalar inputs, but each variable maps to a std::vector<double>.
47  const std::map<std::string, std::map<std::string, std::vector<double>>>& getVectorInputs() const {
48  return m_vectorInputs;
49  }
50 
51  // Calculate the NN response for the default output node and label.
 // Forwards to the four-argument overload with m_outputNode and m_outputLabel.
 // The input maps are taken by non-const reference; whether they are modified
 // depends on the out-of-line overload, whose definition is not in this header.
52  double evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
53  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
54  return evaluate(scalarInputs, vectorInputs, m_outputNode, m_outputLabel);
55  }
56 
57  // Calculate the NN response for an explicitly chosen output node and label.
 //   scalarInputs / vectorInputs : input values, in the nested-map layout exposed by
 //                                 getScalarInputs() / getVectorInputs()
 //   outputNode / outputLabel    : which output of the network to return
 // Declared here only; defined out of line.
58  double evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
59  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
60  const std::string& outputNode, const std::string& outputLabel) const;
61 
62  // Set up the NN from the given string and report success/failure as a StatusCode.
 // NOTE(review): presumably `json` is the serialised lwtnn graph configuration and this call
 // fills m_nn, m_nnConfig, the input templates and the default output names — confirm in TRTPIDNN.cxx.
63  StatusCode configure(const std::string& json);
64 
65  private:
66  std::unique_ptr<lwt::LightweightGraph> m_nn; // the NN
67  lwt::GraphConfig m_nnConfig; // configuration of the NN
68  std::map<std::string, std::map<std::string, double>> m_scalarInputs; // template for the structure of the scalar inputs to the NN
69  std::map<std::string, std::map<std::string, std::vector<double>>> m_vectorInputs; // template for the structure of the vector inputs to the NN
70  std::string m_outputNode; // name of the default output node of the NN (used by the two-argument evaluate())
71  std::string m_outputLabel; // name of the default output label of the NN (used by the two-argument evaluate())
72 };
73 }
// Associate a CLID (341715853) and a version (1) with the type InDet::TRTPIDNN.
74 CLASS_DEF(InDet::TRTPIDNN,341715853,1)
// NOTE(review): presumably declares the conditions container (CondCont) for InDet::TRTPIDNN
// with its own CLID 710491600 — confirm against AthenaKernel/CondCont.h.
75 CONDCONT_DEF(InDet::TRTPIDNN,710491600);
76 
77 #endif
InDet::TRTPIDNN::getDefaultOutputNode
const std::string & getDefaultOutputNode() const
Definition: TRTPIDNN.h:33
InDet::TRTPIDNN
Definition: TRTPIDNN.h:28
CondCont.h
Hold mappings of ranges to condition objects.
json
nlohmann::json json
Definition: HistogramDef.cxx:9
InDet
Primary Vertex Finder.
Definition: VP1ErrorUtils.h:36
outputLabel
const std::string outputLabel
Definition: OverlapRemovalTester.cxx:69
InDet::TRTPIDNN::m_outputLabel
std::string m_outputLabel
Definition: TRTPIDNN.h:71
InDet::TRTPIDNN::configure
StatusCode configure(const std::string &json)
Definition: TRTPIDNN.cxx:35
InDet::TRTPIDNN::getVectorInputs
const std::map< std::string, std::map< std::string, std::vector< double > > > & getVectorInputs() const
Definition: TRTPIDNN.h:47
InDet::TRTPIDNN::m_nn
std::unique_ptr< lwt::LightweightGraph > m_nn
Definition: TRTPIDNN.h:66
CONDCONT_DEF
CONDCONT_DEF(InDet::TRTPIDNN, 710491600)
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
InDet::TRTPIDNN::getScalarInputs
const std::map< std::string, std::map< std::string, double > > & getScalarInputs() const
Definition: TRTPIDNN.h:42
InDet::TRTPIDNN::TRTPIDNN
TRTPIDNN()=default
InDet::TRTPIDNN::m_nnConfig
lwt::GraphConfig m_nnConfig
Definition: TRTPIDNN.h:67
CLASS_DEF
#define CLASS_DEF(NAME, CID, VERSION)
associate a clid and a version to a type eg
Definition: Control/AthenaKernel/AthenaKernel/CLASS_DEF.h:64
InDet::TRTPIDNN::getDefaultOutputLabel
const std::string & getDefaultOutputLabel() const
Definition: TRTPIDNN.h:37
InDet::TRTPIDNN::m_scalarInputs
std::map< std::string, std::map< std::string, double > > m_scalarInputs
Definition: TRTPIDNN.h:68
InDet::TRTPIDNN::m_outputNode
std::string m_outputNode
Definition: TRTPIDNN.h:70
InDet::TRTPIDNN::evaluate
double evaluate(std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TRTPIDNN.h:52
InDet::TRTPIDNN::~TRTPIDNN
virtual ~TRTPIDNN()=default
CLASS_DEF.h
macros to associate a CLID to a type
InDet::TRTPIDNN::m_vectorInputs
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
Definition: TRTPIDNN.h:69