ATLAS Offline Software
Loading...
Searching...
No Matches
InDet::TRTPIDNN Class Reference

#include <TRTPIDNN.h>

Collaboration diagram for InDet::TRTPIDNN: (diagram image not reproduced in this text version)

Public Member Functions

 TRTPIDNN ()=default
virtual ~TRTPIDNN ()=default
const std::string & getDefaultOutputNode () const
const std::string & getDefaultOutputLabel () const
const std::map< std::string, std::map< std::string, double > > & getScalarInputs () const
const std::map< std::string, std::map< std::string, std::vector< double > > > & getVectorInputs () const
double evaluate (std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
double evaluate (std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs, const std::string &outputNode, const std::string &outputLabel) const
StatusCode configure (const std::string &json)

Private Attributes

std::unique_ptr< lwt::LightweightGraph > m_nn
lwt::GraphConfig m_nnConfig
std::map< std::string, std::map< std::string, double > > m_scalarInputs
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
std::string m_outputNode
std::string m_outputLabel

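Detailed Description

Wrapper around an lwt::LightweightGraph neural network.

configure() does four things:
- builds the network from a JSON string;
- selects the first output node as the default output node;
- selects that node's first label as the default output label;
- records templates of the scalar inputs and the vector (sequence) inputs that the network expects.

evaluate() runs the network and returns a single output value.

evaluate() dereferences the network without checking that it exists. It must therefore only be called after configure() has returned StatusCode::SUCCESS.

Definition at line 28 of file TRTPIDNN.h.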

Constructor & Destructor Documentation

◆ TRTPIDNN()

InDet::TRTPIDNN::TRTPIDNN ( )
default

◆ ~TRTPIDNN()

virtual InDet::TRTPIDNN::~TRTPIDNN ( )
virtual, default

Member Function Documentation

◆ configure()

StatusCode InDet::TRTPIDNN::configure ( const std::string & json)

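Parses json as an lwt graph configuration and constructs the network from it.

Returns StatusCode::FAILURE, and logs an error, in three cases:
- the JSON cannot be parsed;
- the network cannot be constructed;
- the configuration declares no output node, or the first output node has no labels.

On success, the following happens and StatusCode::SUCCESS is returned:
- the default output node is set to the first output node;
- the default output label is set to that node's first label;
- the scalar input templates are rebuilt, pre-filled with the configuration's default values;
- the vector input templates are rebuilt, initialized to empty vectors.

Definition at line 35 of file TRTPIDNN.cxx.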

35 {
36 MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
37 try {
38 std::istringstream inputCfg(json);
39 m_nnConfig = lwt::parse_json_graph(inputCfg);
40 } catch (boost::property_tree::ptree_error& err) {
41 log << MSG::ERROR << " NN not readable: " << err.what() << endmsg;
42 return StatusCode::FAILURE;
43 }
44
45 try {
46 m_nn = std::make_unique<lwt::LightweightGraph>(m_nnConfig);
47 } catch (lwt::NNConfigurationException& err) {
48 log << MSG::ERROR << " NN configuration failed: " << err.what() << endmsg;
49 return StatusCode::FAILURE;
50 }
51
52 // set the default output node name
53 if (m_nnConfig.outputs.empty() or m_nnConfig.outputs.begin()->second.labels.empty()) {
54 log << MSG::ERROR << " unable to define NN output." << endmsg;
55 return StatusCode::FAILURE;
56 }
57 m_outputNode = m_nnConfig.outputs.begin()->first;
58 m_outputLabel = *(m_nnConfig.outputs[m_outputNode].labels.begin());
59
60 // store templates of the structure of the inputs to the NN
61 m_scalarInputs.clear();
62 for (auto input : m_nnConfig.inputs) {
63 m_scalarInputs[input.name] = {};
64 for (const auto& variable : input.variables) {
65 m_scalarInputs[input.name][variable.name] = input.defaults[variable.name];
66 }
67 }
68 m_vectorInputs.clear();
69 for (const auto& input : m_nnConfig.input_sequences) {
70 m_vectorInputs[input.name] = {};
71 for (const auto& variable : input.variables) {
72 m_vectorInputs[input.name][variable.name] = {};
73 }
74 }
75
76 return StatusCode::SUCCESS;
77}
Symbols referenced in the listing above (Doxygen tooltips):
#define endmsg
nlohmann::json json
(Cross-reference artifact: Doxygen linked the name json to this unrelated symbol. In configure(), json is the const std::string & parameter.)
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
Definition TRTPIDNN.h:69
std::map< std::string, std::map< std::string, double > > m_scalarInputs
Definition TRTPIDNN.h:68
lwt::GraphConfig m_nnConfig
Definition TRTPIDNN.h:67
std::string m_outputLabel
Definition TRTPIDNN.h:71
std::unique_ptr< lwt::LightweightGraph > m_nn
Definition TRTPIDNN.h:66
std::string m_outputNode
Definition TRTPIDNN.h:70
IMessageSvc * getMessageSvc(bool quiet=false)

◆ evaluate() [1/2]

double InDet::TRTPIDNN::evaluate ( std::map< std::string, std::map< std::string, double > > & scalarInputs,
std::map< std::string, std::map< std::string, std::vector< double > > > & vectorInputs ) const
inline

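Evaluates the network using the default output node and label, i.e. the values returned by getDefaultOutputNode() and getDefaultOutputLabel().

Definition at line 52 of file TRTPIDNN.h.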

53 {
54 return evaluate(scalarInputs, vectorInputs, m_outputNode, m_outputLabel);
55 }
Symbol referenced in the listing above (Doxygen tooltip):
double evaluate(std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
Definition TRTPIDNN.h:52

◆ evaluate() [2/2]

double InDet::TRTPIDNN::evaluate ( std::map< std::string, std::map< std::string, double > > & scalarInputs,
std::map< std::string, std::map< std::string, std::vector< double > > > & vectorInputs,
const std::string & outputNode,
const std::string & outputLabel ) const

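Runs the network on the given inputs for output node outputNode and returns the value stored under outputLabel. If outputLabel is not among the computed outputs, an error is logged and 0.5 is returned.

Definition at line 22 of file TRTPIDNN.cxx.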

24 {
25 MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
26 const auto result = m_nn->compute(scalarInputs, vectorInputs, outputNode);
27 const auto itResult = result.find(outputLabel);
28 if (itResult == result.end()) {
29 log << MSG::ERROR << " unable to find output: node=" << outputNode << ", label=" << outputLabel << endmsg;
30 return 0.5;
31 }
32 return itResult->second;
33}
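Symbol referenced in the listing above (Doxygen tooltip; the evaluate() parameter of the same name is declared const std::string &):
const std::string outputLabel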

◆ getDefaultOutputLabel()

const std::string & InDet::TRTPIDNN::getDefaultOutputLabel ( ) const
inline

Definition at line 37 of file TRTPIDNN.h.

37 {
38 return m_outputLabel;
39 }

◆ getDefaultOutputNode()

const std::string & InDet::TRTPIDNN::getDefaultOutputNode ( ) const
inline

Definition at line 33 of file TRTPIDNN.h.

33 {
34 return m_outputNode;
35 }

◆ getScalarInputs()

const std::map< std::string, std::map< std::string, double > > & InDet::TRTPIDNN::getScalarInputs ( ) const
inline

Definition at line 42 of file TRTPIDNN.h.

42 {
43 return m_scalarInputs;
44 }

◆ getVectorInputs()

const std::map< std::string, std::map< std::string, std::vector< double > > > & InDet::TRTPIDNN::getVectorInputs ( ) const
inline

Definition at line 47 of file TRTPIDNN.h.

47 {
48 return m_vectorInputs;
49 }

Member Data Documentation

◆ m_nn

std::unique_ptr<lwt::LightweightGraph> InDet::TRTPIDNN::m_nn
private

Definition at line 66 of file TRTPIDNN.h.

◆ m_nnConfig

lwt::GraphConfig InDet::TRTPIDNN::m_nnConfig
private

Definition at line 67 of file TRTPIDNN.h.

◆ m_outputLabel

std::string InDet::TRTPIDNN::m_outputLabel
private

Definition at line 71 of file TRTPIDNN.h.

◆ m_outputNode

std::string InDet::TRTPIDNN::m_outputNode
private

Definition at line 70 of file TRTPIDNN.h.

◆ m_scalarInputs

std::map<std::string, std::map<std::string, double> > InDet::TRTPIDNN::m_scalarInputs
private

Definition at line 68 of file TRTPIDNN.h.

◆ m_vectorInputs

std::map<std::string, std::map<std::string, std::vector<double> > > InDet::TRTPIDNN::m_vectorInputs
private

Definition at line 69 of file TRTPIDNN.h.


The documentation for this class was generated from the following files: