ATLAS Offline Software
Loading...
Searching...
No Matches
TRTPIDNN.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3*/
6#include "GaudiKernel/MsgStream.h"
7#include <iostream>
8#include <memory>
9
10
11// lwtnn includes
12//#include "lwtnn/LightweightGraph.hh"
13#include "lwtnn/Exceptions.hh"
14#include "lwtnn/parse_json.hh"
15
16// JSON parsing
17#define BOOST_BIND_GLOBAL_PLACEHOLDERS // Needed to silence Boost pragma message
18#include "boost/property_tree/ptree.hpp"
19#include "boost/property_tree/json_parser.hpp"
20#include "boost/property_tree/exceptions.hpp"
21
22double InDet::TRTPIDNN::evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
23 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
24 const std::string& outputNode, const std::string& outputLabel) const {
25 MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
26 const auto result = m_nn->compute(scalarInputs, vectorInputs, outputNode);
27 const auto itResult = result.find(outputLabel);
28 if (itResult == result.end()) {
29 log << MSG::ERROR << " unable to find output: node=" << outputNode << ", label=" << outputLabel << endmsg;
30 return 0.5;
31 }
32 return itResult->second;
33}
34
35StatusCode InDet::TRTPIDNN::configure(const std::string& json) {
36 MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
37 try {
38 std::istringstream inputCfg(json);
39 m_nnConfig = lwt::parse_json_graph(inputCfg);
40 } catch (boost::property_tree::ptree_error& err) {
41 log << MSG::ERROR << " NN not readable: " << err.what() << endmsg;
42 return StatusCode::FAILURE;
43 }
44
45 try {
46 m_nn = std::make_unique<lwt::LightweightGraph>(m_nnConfig);
47 } catch (lwt::NNConfigurationException& err) {
48 log << MSG::ERROR << " NN configuration failed: " << err.what() << endmsg;
49 return StatusCode::FAILURE;
50 }
51
52 // set the default output node name
53 if (m_nnConfig.outputs.empty() or m_nnConfig.outputs.begin()->second.labels.empty()) {
54 log << MSG::ERROR << " unable to define NN output." << endmsg;
55 return StatusCode::FAILURE;
56 }
57 m_outputNode = m_nnConfig.outputs.begin()->first;
58 m_outputLabel = *(m_nnConfig.outputs[m_outputNode].labels.begin());
59
60 // store templates of the structure of the inputs to the NN
61 m_scalarInputs.clear();
62 for (auto input : m_nnConfig.inputs) {
63 m_scalarInputs[input.name] = {};
64 for (const auto& variable : input.variables) {
65 m_scalarInputs[input.name][variable.name] = input.defaults[variable.name];
66 }
67 }
68 m_vectorInputs.clear();
69 for (const auto& input : m_nnConfig.input_sequences) {
70 m_vectorInputs[input.name] = {};
71 for (const auto& variable : input.variables) {
72 m_vectorInputs[input.name][variable.name] = {};
73 }
74 }
75
76 return StatusCode::SUCCESS;
77}
#define endmsg
nlohmann::json json
const std::string outputLabel
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
Definition TRTPIDNN.h:69
std::map< std::string, std::map< std::string, double > > m_scalarInputs
Definition TRTPIDNN.h:68
lwt::GraphConfig m_nnConfig
Definition TRTPIDNN.h:67
StatusCode configure(const std::string &json)
Definition TRTPIDNN.cxx:35
std::string m_outputLabel
Definition TRTPIDNN.h:71
double evaluate(std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
Definition TRTPIDNN.h:52
std::unique_ptr< lwt::LightweightGraph > m_nn
Definition TRTPIDNN.h:66
std::string m_outputNode
Definition TRTPIDNN.h:70
singleton-like access to IMessageSvc via open function and helper
IMessageSvc * getMessageSvc(bool quiet=false)