ATLAS Offline Software
Loading...
Searching...
No Matches
GlobalLargeRDNNCalibration.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
5/* ***********************************************************************************\
6 * *
7 * Name: GlobalLargeRDNNCalibration *
8 * Purpose: Perform the DNN JES and JMS step of the large-R jets' calibration *
9 * *
10 * # Date Comments By *
11 * -- -------- -------------------------- ------------------------------------------ *
12 * 1 31/01/23 First Version G. Albouy, P.-A. Delsart *
13\*************************************************************************************/
14
15
16#ifndef JetCalibTools_GlobalLargeRDNNCalibration_H
17#define JetCalibTools_GlobalLargeRDNNCalibration_H
18
19
20// Other packages includes
21#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h" //Ort::Session
22
23// Local includes
25
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
28
29class EventInfo;
30class TEnv;
31
33
34public:
35 // Constructor/destructor/init
40
45 GlobalLargeRDNNCalibration(const std::string& name);
46
55 GlobalLargeRDNNCalibration(const std::string& name, TEnv * config, const TString& calibArea, bool dev);
56
61
66 virtual StatusCode initialize() override;
67
68
71 struct VarRetriever;
72
73 protected:
74 // @brief Calibrates the jet, and decorates it with the calibration using the name "JetGNNCScaleMomentum"
75 // @param jet_reco The jet
76 // @param jetEventInfo A set of information about the event and jet
77 virtual StatusCode calibrate(xAOD::Jet& jet, JetEventInfo&) const override;
78
79 private:
80
86 std::vector<float> getJetFeatures( xAOD::Jet& jet_reco, JetEventInfo& jetEventInfo) const;
87
88 std::vector<TString> m_NNInputs;
89 std::vector<double> m_eScales;
90 std::vector<double> m_NormOffsets;
91 std::vector<double> m_NormScales;
92 std::string m_modelFileName;
93
94 std::vector<VarRetriever*> m_varretrievers;
95
96 std::unique_ptr< Ort::Session > m_session;
97 std::vector<int64_t> m_input_node_dims;
98 std::vector<const char*> m_input_node_names;
99 std::vector<int64_t> m_output_node_dims;
100 std::vector<const char*> m_output_node_names;
101
102 TEnv * m_config{};
103 std::string m_calibArea;
104 bool m_devMode{};
105
106
107}; // Class GlobalLargeRDNNCalibration
108
109
110#endif
virtual StatusCode calibrate(xAOD::Jet &jet, JetEventInfo &) const override
std::vector< float > getJetFeatures(xAOD::Jet &jet_reco, JetEventInfo &jetEventInfo) const
Returns a vector of input features for the NN.
std::vector< const char * > m_output_node_names
std::vector< int64_t > m_output_node_dims
std::vector< VarRetriever * > m_varretrievers
std::unique_ptr< Ort::Session > m_session
std::vector< int64_t > m_input_node_dims
virtual StatusCode initialize() override
Returns the charged fraction of a jet.
virtual ~GlobalLargeRDNNCalibration()
The destructor.
std::vector< const char * > m_input_node_names
JetCalibrationStep(const char *name="JetCalibrationStep")
Jet_v1 Jet
Definition of the current "jet version".
VarRetriever is a generic class to access Jet and/or JetEventInfo variables.