10#include "lwtnn/parse_json.hh"
11#include "lwtnn/NanReplacer.hh"
22 declareInterface<Prompt::IRNNTool>(
this);
24 declareProperty(
"configPathRNN",
m_configPathRNN,
"Path of the local RNN json file you want o study/test, it will override the PathResolverFindCalibFile file");
38 std::string fullPathToFile;
50 ATH_MSG_INFO(
"initialize RNNTool - ConfigPathRNN: \"" << fullPathToFile);
55 std::ifstream input_stream(fullPathToFile);
57 lwt::GraphConfig graph_config = lwt::parse_json_graph(input_stream);
59 m_graph = std::make_unique<lwt::LightweightGraph>(graph_config);
61 for(
const auto &o: graph_config.outputs) {
62 ATH_MSG_DEBUG(
" output name: " << o.first <<
", node_index=" << o.second.node_index);
64 for(
const auto &l: o.second.labels) {
73 ATH_MSG_DEBUG(
"Number of input sequences: " << graph_config.input_sequences.size());
75 for(
const auto &n: graph_config.input_sequences) {
78 for(
const lwt::Input &v: n.variables) {
83 ATH_MSG_DEBUG(
"Number of inputs: " << graph_config.inputs.size());
85 for(
const auto &n: graph_config.inputs) {
88 for(
const lwt::Input &v: n.variables) {
93 return StatusCode::SUCCESS;
101 lwt::LightweightGraph::NodeMap nodes;
102 lwt::LightweightGraph::SeqNodeMap seqs;
113 if(vmap.size() != 6) {
114 ATH_MSG_WARNING(
"RNNTool::computeRNNOutput - incomplete variables: return empty result");
115 return lwt::ValueMap();
120 for(
const lwt::VectorMap::value_type &v: vmap) {
121 nwid = std::max<unsigned>(v.first.size(), nwid);
124 for(
const lwt::VectorMap::value_type &v: vmap) {
125 ATH_MSG_DEBUG(std::setw(nwid+1) << std::left << v.first <<
" ");
127 for(
const double d: v.second) {
134 for(
const lwt::ValueMap::value_type &v:
result) {
135 ATH_MSG_DEBUG(v.first <<
" score=" << std::setprecision(10) << v.second);
144 std::vector<double> &values)
151 for(
unsigned i = 0; i < nvar; ++i) {
154 if(i < tracks.size()) {
157 if(!track.getVar(var, value)) {
162 values.push_back(value);
#define ATH_MSG_WARNING(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)