6#include "GaudiKernel/MsgStream.h"
13#include "lwtnn/Exceptions.hh"
14#include "lwtnn/parse_json.hh"
17#define BOOST_BIND_GLOBAL_PLACEHOLDERS
18#include "boost/property_tree/ptree.hpp"
19#include "boost/property_tree/json_parser.hpp"
20#include "boost/property_tree/exceptions.hpp"
23 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
24 const std::string& outputNode,
const std::string&
outputLabel)
const {
26 const auto result =
m_nn->compute(scalarInputs, vectorInputs, outputNode);
28 if (itResult ==
result.end()) {
29 log << MSG::ERROR <<
" unable to find output: node=" << outputNode <<
", label=" <<
outputLabel <<
endmsg;
32 return itResult->second;
38 std::istringstream inputCfg(
json);
40 }
catch (boost::property_tree::ptree_error& err) {
41 log << MSG::ERROR <<
" NN not readable: " << err.what() <<
endmsg;
42 return StatusCode::FAILURE;
47 }
catch (lwt::NNConfigurationException& err) {
48 log << MSG::ERROR <<
" NN configuration failed: " << err.what() <<
endmsg;
49 return StatusCode::FAILURE;
54 log << MSG::ERROR <<
" unable to define NN output." <<
endmsg;
55 return StatusCode::FAILURE;
64 for (
const auto& variable : input.variables) {
65 m_scalarInputs[input.name][variable.name] = input.defaults[variable.name];
69 for (
const auto& input :
m_nnConfig.input_sequences) {
71 for (
const auto& variable : input.variables) {
76 return StatusCode::SUCCESS;
const std::string outputLabel
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
std::map< std::string, std::map< std::string, double > > m_scalarInputs
lwt::GraphConfig m_nnConfig
StatusCode configure(const std::string &json)
std::string m_outputLabel
double evaluate(std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
std::unique_ptr< lwt::LightweightGraph > m_nn
singleton-like access to IMessageSvc via open function and helper
IMessageSvc * getMessageSvc(bool quiet=false)