ATLAS Offline Software
Loading...
Searching...
No Matches
TRTPIDNN.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef INDETTRTPIDNN_H
5#define INDETTRTPIDNN_H
6
8// TRTPIDNN.h, (c) ATLAS Detector software
10
/****************************************************************************************\

 Wrapper class around the lwtnn (lightweight neural network) instance of the TRT PID NN.
 Objects of this class are instantiated in PIDNNCondAlg.

 Author: Christian Grefe (christian.grefe@cern.ch)

\****************************************************************************************/
18#include "GaudiKernel/StatusCode.h"
21#include "lwtnn/LightweightGraph.hh"
22#include <map>
23#include <memory>
24#include <string>
25#include <vector>
26
27namespace InDet {
28 class TRTPIDNN {
29 public:
30 TRTPIDNN()=default;
31 virtual ~TRTPIDNN()=default;
32
33 const std::string& getDefaultOutputNode() const {
34 return m_outputNode;
35 }
36
37 const std::string& getDefaultOutputLabel() const {
38 return m_outputLabel;
39 }
40
41 // get the structure of the scalar inputs to the NN
42 const std::map<std::string, std::map<std::string, double>>& getScalarInputs() const {
43 return m_scalarInputs;
44 }
45
46 // get the structure of the vector inputs to the NN
47 const std::map<std::string, std::map<std::string, std::vector<double>>>& getVectorInputs() const {
48 return m_vectorInputs;
49 }
50
51 // calculate NN response for default output node and label
52 double evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
53 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
54 return evaluate(scalarInputs, vectorInputs, m_outputNode, m_outputLabel);
55 }
56
57 // calculate NN response
58 double evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
59 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
60 const std::string& outputNode, const std::string& outputLabel) const;
61
62 // set up the NN
63 StatusCode configure(const std::string& json);
64
65 private:
66 std::unique_ptr<lwt::LightweightGraph> m_nn; // the NN
67 lwt::GraphConfig m_nnConfig; // cofiguration of the NN
68 std::map<std::string, std::map<std::string, double>> m_scalarInputs; // template for the structure of the scalar inputs to the NN
69 std::map<std::string, std::map<std::string, std::vector<double>>> m_vectorInputs; // template for the structure of the vector inputs to the NN
70 std::string m_outputNode; // name of the output node of the NN
71 std::string m_outputLabel; // name of the output label of the NN
72};
73}
// Associate a class ID (CLID 341715853, version 1) with InDet::TRTPIDNN.
CLASS_DEF(InDet::TRTPIDNN,341715853,1)
// Declare the condition container (CondCont, which holds mappings of ranges to
// condition objects) for InDet::TRTPIDNN. The number is presumably the CLID of
// CondCont<InDet::TRTPIDNN> — confirm against the macro in CondCont.h.
CONDCONT_DEF(InDet::TRTPIDNN,710491600);
76
77#endif
Hold mappings of ranges to condition objects.
#define CONDCONT_DEF(...)
Definition CondCont.h:1413
macros to associate a CLID to a type
#define CLASS_DEF(NAME, CID, VERSION)
associate a clid and a version to a type eg
nlohmann::json json
const std::string outputLabel
const std::string & getDefaultOutputLabel() const
Definition TRTPIDNN.h:37
std::map< std::string, std::map< std::string, std::vector< double > > > m_vectorInputs
Definition TRTPIDNN.h:69
std::map< std::string, std::map< std::string, double > > m_scalarInputs
Definition TRTPIDNN.h:68
const std::string & getDefaultOutputNode() const
Definition TRTPIDNN.h:33
const std::map< std::string, std::map< std::string, double > > & getScalarInputs() const
Definition TRTPIDNN.h:42
lwt::GraphConfig m_nnConfig
Definition TRTPIDNN.h:67
virtual ~TRTPIDNN()=default
TRTPIDNN()=default
StatusCode configure(const std::string &json)
Definition TRTPIDNN.cxx:35
std::string m_outputLabel
Definition TRTPIDNN.h:71
double evaluate(std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
Definition TRTPIDNN.h:52
std::unique_ptr< lwt::LightweightGraph > m_nn
Definition TRTPIDNN.h:66
const std::map< std::string, std::map< std::string, std::vector< double > > > & getVectorInputs() const
Definition TRTPIDNN.h:47
std::string m_outputNode
Definition TRTPIDNN.h:70
Primary Vertex Finder.