10 #include "lwtnn/parse_json.hh"
11 #include "lwtnn/NanReplacer.hh"
22 declareInterface<Prompt::IRNNTool>(
this);
24 declareProperty(
"configPathRNN",
m_configPathRNN,
"Path of the local RNN json file you want o study/test, it will override the PathResolverFindCalibFile file");
38 std::string fullPathToFile;
40 if(!m_configPathRNN.empty()) {
41 ATH_MSG_INFO(
"Override PathResolver to this path: " << m_configPathRNN);
42 fullPathToFile = m_configPathRNN;
46 + m_configRNNVersion +
"/"
47 + m_configRNNJsonFile);
50 ATH_MSG_INFO(
"initialize RNNTool - ConfigPathRNN: \"" << fullPathToFile);
55 std::ifstream input_stream(fullPathToFile);
59 m_graph = std::make_unique<lwt::LightweightGraph>(graph_config);
61 for(
const auto &o: graph_config.outputs) {
62 ATH_MSG_DEBUG(
" output name: " << o.first <<
", node_index=" << o.second.node_index);
64 for(
const auto &
l: o.second.labels) {
67 if(!m_outputLabels.insert(
l).second) {
73 ATH_MSG_DEBUG(
"Number of input sequences: " << graph_config.input_sequences.size());
75 for(
const auto &
n: graph_config.input_sequences) {
83 ATH_MSG_DEBUG(
"Number of inputs: " << graph_config.inputs.size());
85 for(
const auto &
n: graph_config.inputs) {
93 return StatusCode::SUCCESS;
108 AddVariable(tracks,
Def::Z0Sin, vmap[
"m_cone_tracks_Z0Sin"]);
109 AddVariable(tracks,
Def::D0Sig, vmap[
"m_cone_tracks_D0Sig"]);
110 AddVariable(tracks,
Def::TrackJetDR, vmap[
"m_cone_tracks_DRTrackJet"]);
113 if(vmap.size() != 6) {
114 ATH_MSG_WARNING(
"RNNTool::computeRNNOutput - incomplete variables: return empty result");
120 for(
const lwt::VectorMap::value_type &
v: vmap) {
121 nwid = std::max<unsigned>(
v.first.size(), nwid);
124 for(
const lwt::VectorMap::value_type &
v: vmap) {
127 for(
const double d:
v.second) {
134 for(
const lwt::ValueMap::value_type &
v:
result) {
135 ATH_MSG_DEBUG(
v.first <<
" score=" << std::setprecision(10) <<
v.second);
144 std::vector<double> &
values)
149 const unsigned nvar = std::min<unsigned>(tracks.size(), m_inputSequenceSize);
151 for(
unsigned i = 0;
i < nvar; ++
i) {
154 if(
i < tracks.size()) {