ATLAS Offline Software
Loading...
Searching...
No Matches
RNNTool.h
Go to the documentation of this file.
1// This is -*- c++ -*-
2
3/*
4 Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
5*/
6
7#ifndef ATH_RNNTOOL_H
8#define ATH_RNNTOOL_H
9
10/**********************************************************************************
11 * @Package: PhysicsAnpProd
12 * @Class : RNNTool
13 * @Author : Rustem Ospanov
14 *
15 * @Brief : Tool to access RNN from lwtnn library
16 *
17 **********************************************************************************/
18
19// Tools
21
22// Local
23#include "IRNNTool.h"
24
25// Athena
27
28// External
29#include "lwtnn/LightweightGraph.hh"
30
31namespace Prompt
32{
33 // Forward declarations
34
35 class VarHolder;
36
37 // Main body
38
39 class RNNTool : public AthAlgTool, public IRNNTool
40 {
41 /*
42 RNN tool is based on the lwtnn package.
43 1. It will take the inputs VarHolder objects and convert them into lwt::VectorMap format.
44 2. lwtnn will use lwt::VectorMap format inputs with a given RNN weight json file to predict RNN scores.
45 3. Then this tool will return std::map<std::string, double> object, which contains the name string of the RNN scores and their predictions.
46
47 */
48 public:
49
50 RNNTool(const std::string &name,
51 const std::string &type,
52 const IInterface *parent);
53
54 virtual StatusCode initialize() override;
55
56 virtual std::map<std::string, double> computeRNNOutput(
57 const std::vector<Prompt::VarHolder> &tracks
58 ) override;
59
60 virtual std::set<std::string> getOutputLabels() const override { return m_outputLabels; }
61
62 private:
63
64 void AddVariable(
65 const std::vector<Prompt::VarHolder> &tracks, unsigned var, std::vector<double> &values
66 );
67
68 private:
69
70 std::string m_configPathRNN;
71 std::string m_configRNNVersion;
73
76
77 std::set<std::string> m_outputLabels;
78
79 std::unique_ptr<lwt::LightweightGraph> m_graph;
80 };
81}
82
83#endif
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
RNNTool(const std::string &name, const std::string &type, const IInterface *parent)
Definition RNNTool.cxx:19
virtual std::map< std::string, double > computeRNNOutput(const std::vector< Prompt::VarHolder > &tracks) override
Definition RNNTool.cxx:97
unsigned m_inputSequenceSize
Definition RNNTool.h:75
virtual StatusCode initialize() override
Definition RNNTool.cxx:33
std::string m_inputSequenceName
Definition RNNTool.h:74
void AddVariable(const std::vector< Prompt::VarHolder > &tracks, unsigned var, std::vector< double > &values)
Definition RNNTool.cxx:142
std::string m_configPathRNN
Definition RNNTool.h:70
std::set< std::string > m_outputLabels
Definition RNNTool.h:77
virtual std::set< std::string > getOutputLabels() const override
Definition RNNTool.h:60
std::string m_configRNNJsonFile
Definition RNNTool.h:72
std::string m_configRNNVersion
Definition RNNTool.h:71
std::unique_ptr< lwt::LightweightGraph > m_graph
Definition RNNTool.h:79