ATLAS Offline Software
Public Member Functions | Private Attributes | List of all members
InDet::TRTPIDNN Class Reference

#include <TRTPIDNN.h>

Collaboration diagram for InDet::TRTPIDNN:

Public Member Functions

 TRTPIDNN ()=default
 
virtual ~TRTPIDNN ()=default
 
const std::string & getDefaultOutputNode () const
 
const std::string & getDefaultOutputLabel () const
 
const std::map< std::string, std::map< std::string, double > > & getScalarInputs () const
 
const std::map< std::string, std::map< std::string, std::vector< double > > > & getVectorInputs () const
 
double evaluate (std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
 
double evaluate (std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs, const std::string &outputNode, const std::string &outputLabel) const
 
StatusCode configure (const std::string &json)
 

Private Attributes

std::unique_ptr< lwt::LightweightGraph > m_nn
 
lwt::GraphConfig m_nnConfig
 
std::map< std::string, std::map< std::string, double > > m_scalarInputs
 
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
 
std::string m_outputNode
 
std::string m_outputLabel
 

Detailed Description

Definition at line 28 of file TRTPIDNN.h.

Constructor & Destructor Documentation

◆ TRTPIDNN()

InDet::TRTPIDNN::TRTPIDNN ( )
default

◆ ~TRTPIDNN()

virtual InDet::TRTPIDNN::~TRTPIDNN ( )
virtual, default

Member Function Documentation

◆ configure()

StatusCode InDet::TRTPIDNN::configure ( const std::string &  json)

Definition at line 35 of file TRTPIDNN.cxx.

35  {
36  MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
37  try {
38  std::istringstream inputCfg(json);
40  } catch (boost::property_tree::ptree_error& err) {
41  log << MSG::ERROR << " NN not readable: " << err.what() << endmsg;
42  return StatusCode::FAILURE;
43  }
44 
45  try {
46  m_nn = std::make_unique<lwt::LightweightGraph>(m_nnConfig);
47  } catch (lwt::NNConfigurationException& err) {
48  log << MSG::ERROR << " NN configuration failed: " << err.what() << endmsg;
49  return StatusCode::FAILURE;
50  }
51 
52  // set the default output node name
53  if (m_nnConfig.outputs.empty() or m_nnConfig.outputs.begin()->second.labels.empty()) {
54  log << MSG::ERROR << " unable to define NN output." << endmsg;
55  return StatusCode::FAILURE;
56  }
57  m_outputNode = m_nnConfig.outputs.begin()->first;
58  m_outputLabel = *(m_nnConfig.outputs[m_outputNode].labels.begin());
59 
60  // store templates of the structure of the inputs to the NN
61  m_scalarInputs.clear();
62  for (auto input : m_nnConfig.inputs) {
63  m_scalarInputs[input.name] = {};
64  for (const auto& variable : input.variables) {
65  m_scalarInputs[input.name][variable.name] = input.defaults[variable.name];
66  }
67  }
68  m_vectorInputs.clear();
69  for (const auto& input : m_nnConfig.input_sequences) {
70  m_vectorInputs[input.name] = {};
71  for (const auto& variable : input.variables) {
72  m_vectorInputs[input.name][variable.name] = {};
73  }
74  }
75 
76  return StatusCode::SUCCESS;
77 }

◆ evaluate() [1/2]

double InDet::TRTPIDNN::evaluate ( std::map< std::string, std::map< std::string, double >> &  scalarInputs,
std::map< std::string, std::map< std::string, std::vector< double >>> &  vectorInputs 
) const
inline

Definition at line 52 of file TRTPIDNN.h.

53  {
54  return evaluate(scalarInputs, vectorInputs, m_outputNode, m_outputLabel);
55  }

◆ evaluate() [2/2]

double InDet::TRTPIDNN::evaluate ( std::map< std::string, std::map< std::string, double >> &  scalarInputs,
std::map< std::string, std::map< std::string, std::vector< double >>> &  vectorInputs,
const std::string &  outputNode,
const std::string &  outputLabel 
) const

Definition at line 22 of file TRTPIDNN.cxx.

24  {
25  MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
26  const auto result = m_nn->compute(scalarInputs, vectorInputs, outputNode);
27  const auto itResult = result.find(outputLabel);
28  if (itResult == result.end()) {
29  log << MSG::ERROR << " unable to find output: node=" << outputNode << ", label=" << outputLabel << endmsg;
30  return 0.5;
31  }
32  return itResult->second;
33 }

◆ getDefaultOutputLabel()

const std::string& InDet::TRTPIDNN::getDefaultOutputLabel ( ) const
inline

Definition at line 37 of file TRTPIDNN.h.

37  {
38  return m_outputLabel;
39  }

◆ getDefaultOutputNode()

const std::string& InDet::TRTPIDNN::getDefaultOutputNode ( ) const
inline

Definition at line 33 of file TRTPIDNN.h.

33  {
34  return m_outputNode;
35  }

◆ getScalarInputs()

const std::map<std::string, std::map<std::string, double> >& InDet::TRTPIDNN::getScalarInputs ( ) const
inline

Definition at line 42 of file TRTPIDNN.h.

42  {
43  return m_scalarInputs;
44  }

◆ getVectorInputs()

const std::map<std::string, std::map<std::string, std::vector<double> > >& InDet::TRTPIDNN::getVectorInputs ( ) const
inline

Definition at line 47 of file TRTPIDNN.h.

47  {
48  return m_vectorInputs;
49  }

Member Data Documentation

◆ m_nn

std::unique_ptr<lwt::LightweightGraph> InDet::TRTPIDNN::m_nn
private

Definition at line 66 of file TRTPIDNN.h.

◆ m_nnConfig

lwt::GraphConfig InDet::TRTPIDNN::m_nnConfig
private

Definition at line 67 of file TRTPIDNN.h.

◆ m_outputLabel

std::string InDet::TRTPIDNN::m_outputLabel
private

Definition at line 71 of file TRTPIDNN.h.

◆ m_outputNode

std::string InDet::TRTPIDNN::m_outputNode
private

Definition at line 70 of file TRTPIDNN.h.

◆ m_scalarInputs

std::map<std::string, std::map<std::string, double> > InDet::TRTPIDNN::m_scalarInputs
private

Definition at line 68 of file TRTPIDNN.h.

◆ m_vectorInputs

std::map<std::string, std::map<std::string, std::vector<double> > > InDet::TRTPIDNN::m_vectorInputs
private

Definition at line 69 of file TRTPIDNN.h.


The documentation for this class was generated from the following files: TRTPIDNN.h and TRTPIDNN.cxx.
get_generator_info.result
result
Definition: get_generator_info.py:21
json
nlohmann::json json
Definition: HistogramDef.cxx:9
outputLabel
const std::string outputLabel
Definition: OverlapRemovalTester.cxx:69
InDet::TRTPIDNN::m_outputLabel
std::string m_outputLabel
Definition: TRTPIDNN.h:71
InDet::TRTPIDNN::m_nn
std::unique_ptr< lwt::LightweightGraph > m_nn
Definition: TRTPIDNN.h:66
Athena::getMessageSvc
IMessageSvc * getMessageSvc(bool quiet=false)
Definition: getMessageSvc.cxx:20
dqt_zlumi_pandas.err
err
Definition: dqt_zlumi_pandas.py:182
endmsg
#define endmsg
Definition: AnalysisConfig_Ntuple.cxx:63
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
python.selection.variable
variable
Definition: selection.py:33
InDet::TRTPIDNN::m_nnConfig
lwt::GraphConfig m_nnConfig
Definition: TRTPIDNN.h:67
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
InDet::TRTPIDNN::m_scalarInputs
std::map< std::string, std::map< std::string, double > > m_scalarInputs
Definition: TRTPIDNN.h:68
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
InDet::TRTPIDNN::m_outputNode
std::string m_outputNode
Definition: TRTPIDNN.h:70
InDet::TRTPIDNN::evaluate
double evaluate(std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TRTPIDNN.h:52
InDet::TRTPIDNN::m_vectorInputs
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
Definition: TRTPIDNN.h:69