ATLAS Offline Software
TRTPIDNN.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
6 #include "GaudiKernel/MsgStream.h"
7 #include <iostream>
8 #include <memory>
9 
10 
11 // lwtnn includes
12 //#include "lwtnn/LightweightGraph.hh"
13 #include "lwtnn/Exceptions.hh"
14 #include "lwtnn/parse_json.hh"
15 
16 // JSON parsing
17 #define BOOST_BIND_GLOBAL_PLACEHOLDERS // Needed to silence Boost pragma message
18 #include "boost/property_tree/ptree.hpp"
19 #include "boost/property_tree/json_parser.hpp"
20 #include "boost/property_tree/exceptions.hpp"
21 
22 double InDet::TRTPIDNN::evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
23  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
24  const std::string& outputNode, const std::string& outputLabel) const {
25  MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
26  const auto result = m_nn->compute(scalarInputs, vectorInputs, outputNode);
27  const auto itResult = result.find(outputLabel);
28  if (itResult == result.end()) {
29  log << MSG::ERROR << " unable to find output: node=" << outputNode << ", label=" << outputLabel << endmsg;
30  return 0.5;
31  }
32  return itResult->second;
33 }
34 
36  MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
37  try {
38  std::istringstream inputCfg(json);
40  } catch (boost::property_tree::ptree_error& err) {
41  log << MSG::ERROR << " NN not readable: " << err.what() << endmsg;
42  return StatusCode::FAILURE;
43  }
44 
45  try {
46  m_nn = std::make_unique<lwt::LightweightGraph>(m_nnConfig);
47  } catch (lwt::NNConfigurationException& err) {
48  log << MSG::ERROR << " NN configuration failed: " << err.what() << endmsg;
49  return StatusCode::FAILURE;
50  }
51 
52  // set the default output node name
53  if (m_nnConfig.outputs.empty() or m_nnConfig.outputs.begin()->second.labels.empty()) {
54  log << MSG::ERROR << " unable to define NN output." << endmsg;
55  return StatusCode::FAILURE;
56  }
57  m_outputNode = m_nnConfig.outputs.begin()->first;
58  m_outputLabel = *(m_nnConfig.outputs[m_outputNode].labels.begin());
59 
60  // store templates of the structure of the inputs to the NN
61  m_scalarInputs.clear();
62  for (auto input : m_nnConfig.inputs) {
63  m_scalarInputs[input.name] = {};
64  for (const auto& variable : input.variables) {
65  m_scalarInputs[input.name][variable.name] = input.defaults[variable.name];
66  }
67  }
68  m_vectorInputs.clear();
69  for (const auto& input : m_nnConfig.input_sequences) {
70  m_vectorInputs[input.name] = {};
71  for (const auto& variable : input.variables) {
72  m_vectorInputs[input.name][variable.name] = {};
73  }
74  }
75 
76  return StatusCode::SUCCESS;
77 }
get_generator_info.result
result
Definition: get_generator_info.py:21
getMessageSvc.h
singleton-like access to IMessageSvc via open function and helper
TRTPIDNN.h
json
nlohmann::json json
Definition: HistogramDef.cxx:9
outputLabel
const std::string outputLabel
Definition: OverlapRemovalTester.cxx:69
InDet::TRTPIDNN::m_outputLabel
std::string m_outputLabel
Definition: TRTPIDNN.h:71
InDet::TRTPIDNN::configure
StatusCode configure(const std::string &json)
Definition: TRTPIDNN.cxx:35
InDet::TRTPIDNN::m_nn
std::unique_ptr< lwt::LightweightGraph > m_nn
Definition: TRTPIDNN.h:66
Athena::getMessageSvc
IMessageSvc * getMessageSvc(bool quiet=false)
Definition: getMessageSvc.cxx:20
dqt_zlumi_pandas.err
err
Definition: dqt_zlumi_pandas.py:182
endmsg
#define endmsg
Definition: AnalysisConfig_Ntuple.cxx:63
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
python.selection.variable
variable
Definition: selection.py:33
InDet::TRTPIDNN::m_nnConfig
lwt::GraphConfig m_nnConfig
Definition: TRTPIDNN.h:67
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
InDet::TRTPIDNN::m_scalarInputs
std::map< std::string, std::map< std::string, double > > m_scalarInputs
Definition: TRTPIDNN.h:68
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
InDet::TRTPIDNN::m_outputNode
std::string m_outputNode
Definition: TRTPIDNN.h:70
InDet::TRTPIDNN::evaluate
double evaluate(std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TRTPIDNN.h:52
InDet::TRTPIDNN::m_vectorInputs
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
Definition: TRTPIDNN.h:69