ATLAS Offline Software
Loading...
Searching...
No Matches
GlobalLargeRDNNCalibration.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
5// System includes
7
8
9#ifdef XAOD_STANDALONE
11#endif
12
16
17#include <TEnv.h>
18#include <tuple>
19#include <cmath>
20#include <map>
21#include <algorithm> //count_if
22
23
24namespace{
25 // Redefine some functions from the package OnnxRuntimeUtils which is not (yet) available in AnalysisBase
26 // Set up the ONNX Runtime session
27 std::unique_ptr< Ort::Session > CreateORTSession(const std::string& modelFile){
28 Ort::SessionOptions sessionOptions;
29 sessionOptions.SetIntraOpNumThreads( 1 );
30 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
31
32 // Set the ONNX service name depending on the actual analysis release
33 std::string serviceName;
34 #ifdef XAOD_STANDALONE
35 using namespace asg::msgUserCode;
36 ANA_MSG_WARNING("If running DNN calibration in AnalysisBase: necessary to instantiate the ONNX service AthOnnx::OnnxRuntimeSvc with name OnnxRuntimeSvc");
37 ATH_MSG_WARNING("Either in C++ config (see exemple in JetCalibTools_Example.cxx)");
38 ATH_MSG_WARNING("Or in python config with");
39 ATH_MSG_WARNING(" from AnaAlgorithm.DualUseConfig import createService");
40 ATH_MSG_WARNING(" onnxSvc = createService('AthOnnx::OnnxRuntimeSvc', 'OnnxRuntimeSvc', myAlgSequence)");
41 serviceName = "OnnxRuntimeSvc";
42 #else
43 serviceName = "AthOnnx::OnnxRuntimeSvc";
44 #endif
45
46 ServiceHandle< AthOnnx::IOnnxRuntimeSvc > svc(serviceName, "AthOnnx::OnnxRuntimeSvc");
47
48 return std::make_unique<Ort::Session>( svc->env(),
49 modelFile.c_str(),
50 sessionOptions );
51 }
52
53 // Get dimensions and names of the input nodes
54 std::tuple<std::vector<int64_t>, std::vector<const char*> > GetInputNodeInfo(const std::unique_ptr< Ort::Session >& session){
55 std::vector<int64_t> input_node_dims;
56 size_t num_input_nodes = session->GetInputCount();
57 std::vector<const char*> input_node_names(num_input_nodes);
58 Ort::AllocatorWithDefaultOptions allocator;
59 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
60
61 char* input_name = session->GetInputNameAllocated(i, allocator).release();
62 input_node_names[i] = input_name;
63 Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
64 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
65
66 input_node_dims = tensor_info.GetShape();
67 }
68 return std::make_tuple(input_node_dims, input_node_names);
69 }
70
71 // Get dimensions and names of the output nodes
72 std::tuple<std::vector<int64_t>, std::vector<const char*> > GetOutputNodeInfo(const std::unique_ptr< Ort::Session >& session){
73 std::vector<int64_t> output_node_dims;
74 size_t num_output_nodes = session->GetOutputCount();
75 std::vector<const char*> output_node_names(num_output_nodes);
76 Ort::AllocatorWithDefaultOptions allocator;
77
78 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
79 char* output_name = session->GetOutputNameAllocated(i, allocator).release();
80 output_node_names[i] = output_name;
81
82 Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
83 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
84
85 output_node_dims = tensor_info.GetShape();
86 }
87 return std::make_tuple(output_node_dims, output_node_names);
88 }
89}
90
91
95
97 virtual float value(const xAOD::Jet& jet, JetEventInfo& jetInfo, double eScale) = 0;
98 virtual ~VarRetriever()= default;
99};
100
101namespace {
102
// Generic VarRetriever: applies the accessor m_acc, constructed from the
// name `n`, to the jet and multiplies the result by eScale.
// NOTE(review): the declaration of m_acc is not visible in this listing
// (original line 111 is missing) - confirm its type against the full file.
104 struct VarAccessorRetriever : public GlobalLargeRDNNCalibration::VarRetriever {
// n : name passed to the m_acc accessor
105 VarAccessorRetriever(const std::string &n): m_acc(n) {}
106
// Returns m_acc(jet) * eScale; the JetEventInfo argument is not used.
107 virtual float value(const xAOD::Jet& jet, JetEventInfo&, double eScale) {
108 return m_acc(jet) * eScale;
109 }
110
112 };
113
116 struct RatioAccessorRetriever : public GlobalLargeRDNNCalibration::VarRetriever {
117 RatioAccessorRetriever(): m_accTau1("Tau1_wta"),
118 m_accTau2("Tau2_wta"),
119 m_accTau3("Tau3_wta"),
120 m_accECF1("ECF1"),
121 m_accECF2("ECF2"),
122 m_accECF3("ECF3") {}
123
124 virtual float value(const xAOD::Jet& jet, JetEventInfo&, double eScale) = 0;
125
126 SG::AuxElement::ConstAccessor<float> m_accTau1;
127 SG::AuxElement::ConstAccessor<float> m_accTau2;
128 SG::AuxElement::ConstAccessor<float> m_accTau3;
129 SG::AuxElement::ConstAccessor<float> m_accECF1;
130 SG::AuxElement::ConstAccessor<float> m_accECF2;
131 SG::AuxElement::ConstAccessor<float> m_accECF3;
132 };
133
135 #define DEF_RETRIEVER0(cname, expr ) struct Var_##cname : public GlobalLargeRDNNCalibration::VarRetriever { float value(const xAOD::Jet& jet, JetEventInfo& , double eScale ) { return expr ; } }
136 #define DEF_RETRIEVER1(cname, expr ) struct Var_##cname : public GlobalLargeRDNNCalibration::VarRetriever { float value(const xAOD::Jet& , JetEventInfo& jetInfo, double eScale ) { return expr ; } }
137 #define DEF_RATIO_RETRIEVER(cname, expr ) struct Ratio_##cname : public RatioAccessorRetriever { float value(const xAOD::Jet& jet, JetEventInfo& , double eScale ) { return expr ; } }
138
139 // Std jet variables
140 DEF_RETRIEVER0( eta, jet.eta()*eScale ) ;
141 DEF_RETRIEVER0( rapidity, jet.rapidity()*eScale ) ;
142 DEF_RETRIEVER0( log_e, log(jet.e()*eScale) ) ;
143 DEF_RETRIEVER0( log_m, log(jet.m()*eScale) ) ;
144 DEF_RETRIEVER0( m, jet.m()*eScale ) ;
145
146 // Ratio variables -- default values consistent with DNN training
147 DEF_RATIO_RETRIEVER( Tau21_wta, m_accTau1(jet) > 1e-8 ? eScale * m_accTau2(jet) / m_accTau1(jet) : -0.1);
148 DEF_RATIO_RETRIEVER( Tau32_wta, m_accTau2(jet) > 1e-8 ? eScale * m_accTau3(jet) / m_accTau2(jet) : -0.1);
149 DEF_RATIO_RETRIEVER( C2, m_accECF2(jet) > 1e-8 ? eScale * m_accECF3(jet) * m_accECF1(jet) / pow(m_accECF2(jet), 2.0) : -0.1);
150 DEF_RATIO_RETRIEVER( D2, m_accECF2(jet) > 1e-8 ? eScale * m_accECF3(jet) * pow(m_accECF1(jet), 3.0) / pow(m_accECF2(jet), 3.0) : -0.1);
151
152 // Std pile-up info
153 DEF_RETRIEVER1( mu, jetInfo.mu()*eScale );
154 DEF_RETRIEVER1( NPV, jetInfo.NPV()*eScale );
155
156 #undef DEF_RETRIEVER
157
159 GlobalLargeRDNNCalibration::VarRetriever* buildVarRetriever(const std::string & name){
160 // create a map of known specialized VarRetriever.
161 // it's just a map "name" <-> function returning a Var_xyz()
162 static const std::map<std::string, std::function<GlobalLargeRDNNCalibration::VarRetriever*()> > knownVar{
163 {"eta", [](){return new Var_eta();} },
164 {"rapidity", [](){return new Var_rapidity();} },
165 {"log_e", [](){return new Var_log_e();} },
166 {"log_m", [](){return new Var_log_m();} },
167 {"Tau21_wta", [](){return new Ratio_Tau21_wta();} },
168 {"Tau32_wta", [](){return new Ratio_Tau32_wta();} },
169 {"C2", [](){return new Ratio_C2();} },
170 {"D2", [](){return new Ratio_D2();} },
171 {"mu", [](){return new Var_mu();} },
172 {"NPV", [](){return new Var_NPV();} },
173 };
174
175 auto it = knownVar.find(name);
176 // if name is not a known variable, assume it's a jet attribute, so return a generic VarAccessorRetriever
177 if( it == knownVar.end() ) return new VarAccessorRetriever(name);
178 // else we just return an instance of a known VarRetriever class
179 // (it->second is the function : we call it to obtain a new pointer)
180 return it->second();
181 }
182
183}
184
187 : JetCalibrationStep::JetCalibrationStep("GlobalLargeRDNNCalibration/GlobalLargeRDNNCalibration"),
188 m_config(nullptr), m_calibArea("")
189{
190}
191
193 : JetCalibrationStep::JetCalibrationStep(name.c_str()),
194 m_config(nullptr), m_calibArea("")
195{
196}
197
198GlobalLargeRDNNCalibration::GlobalLargeRDNNCalibration(const std::string& name, TEnv * config, const TString& calibArea, bool dev)
199 : JetCalibrationStep::JetCalibrationStep( name.c_str() ),
200 m_config(config), m_calibArea(calibArea), m_devMode(dev)
201{
202}
203
208
209// Initialize
211 ATH_MSG_DEBUG("Initializing tool");
212 if ( !m_config ) { ATH_MSG_FATAL("Config file not specified. Aborting."); return StatusCode::FAILURE; }
213
214 // Get list of input features
215 m_NNInputs = JetCalibUtils::Vectorize( m_config->GetValue("DNNC.Inputs","") );
216 // Now build a VarRetriever for each of the input features
217 m_varretrievers.resize(m_NNInputs.size());
218 ATH_MSG_DEBUG("DNN inputs");
219 for (long unsigned int i=0;i<m_NNInputs.size();i++) {
220 m_varretrievers[i] = buildVarRetriever( m_NNInputs[i].Data() );
221 ATH_MSG_DEBUG(" " << m_NNInputs[i]);
222 }
223
224 // Get normalization constants for input features
225 m_eScales = JetCalibUtils::VectorizeD( m_config->GetValue("DNNC.EScales","") );
226 m_NormOffsets = JetCalibUtils::VectorizeD( m_config->GetValue("DNNC.NormOffsets","") );
227 m_NormScales = JetCalibUtils::VectorizeD( m_config->GetValue("DNNC.NormScales","") );
228
229 if (m_eScales.size()!=m_NNInputs.size() || m_NormOffsets.size()!=m_NNInputs.size() || m_NormScales.size()!=m_NNInputs.size()) {
230 ATH_MSG_FATAL("Misconfiguration of config file : not same number of offset/scale parameters and number of features. Will exit");
231 return StatusCode::FAILURE;
232 }
233
234 if( msgLvl(MSG::DEBUG) ){
235 ATH_MSG_DEBUG("m_NormOffsets size : " << m_NormOffsets.size());
236 ATH_MSG_DEBUG("m_NormOffsets");
237 for (long unsigned int i=0;i<m_NormOffsets.size();i++) {
238 ATH_MSG_DEBUG(" " << m_NormOffsets[i]);
239 }
240 ATH_MSG_DEBUG("m_NormScales size : " << m_NormScales.size());
241 ATH_MSG_DEBUG("m_NormScales");
242 for (long unsigned int i=0;i<m_NormScales.size();i++) {
243 ATH_MSG_DEBUG(" " << m_NormScales[i]);
244 }
245 }
246
247 // Get DNN config file
248 m_modelFileName = m_config->GetValue("DNNC.ONNXInput","");
249 std::string modelPath = "";
250 if (m_devMode) {
251 modelPath="JetCalibTools/"+m_modelFileName;
252 } else {
253 modelPath="JetCalibTools/"+m_calibArea+"CalibrationConfigs/"+m_modelFileName;
254 }
255 const std::string fullModelPath = PathResolverFindCalibFile( modelPath ); // Full path
256 ATH_MSG_INFO("Using ONNX model : " << m_modelFileName);
257 ATH_MSG_INFO("resolved in: " << fullModelPath);
258
259 // Set up the ONNX Runtime session.
260 m_session = CreateORTSession(fullModelPath);
261 ATH_MSG_DEBUG( "ONNX Runtime session succesfully created" );
262
263
264 /************************** Input Nodes *****************************/
265 /*********************************************************************/
266 std::tuple<std::vector<int64_t>, std::vector<const char*> > inputInfo = GetInputNodeInfo(m_session);
267 m_input_node_dims = std::get<0>(inputInfo);
268 m_input_node_names = std::get<1>(inputInfo);
269
270 if( msgLvl(MSG::DEBUG) ){
271 for( std::size_t i = 0; i < m_input_node_names.size(); i++ ) {
272 // print input node names
273 ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<m_input_node_names[i]);
274
275 // print input shapes/dims
276 ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<m_input_node_dims.size());
277 for (std::size_t j = 0; j < m_input_node_dims.size(); j++){
278 ATH_MSG_DEBUG("Input "<<i<<" : dim "<<j<<"= "<<m_input_node_dims[j]);
279 }
280 }
281 }
282
283 /************************** Output Nodes *****************************/
284 /*********************************************************************/
285 std::tuple<std::vector<int64_t>, std::vector<const char*> > outputInfo = GetOutputNodeInfo(m_session);
286 m_output_node_dims = std::get<0>(outputInfo);
287 m_output_node_names = std::get<1>(outputInfo);
288
289 if( msgLvl(MSG::DEBUG) ){
290 for( std::size_t i = 0; i < m_output_node_names.size(); i++ ) {
291 // print output node names
292 ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<m_output_node_names[i]);
293
294 // print output shapes/dims
295 ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<m_output_node_dims.size());
296 for (std::size_t j = 0; j < m_output_node_dims.size(); j++){
297 ATH_MSG_DEBUG("Output "<<i<<" : dim "<<j<<"= "<<m_output_node_dims[j]);
298 }
299 }
300 }
301
302 /**************************************************************************************
303 * m_input_node_dims[0] = -1; -1 needs to be replaced by the batch size; for no batch --> 1
304 * m_input_node_dims[1] should be equal to m_NNInputs.size()
305 ****************************************************************************************/
306 m_input_node_dims[0] = 1;
307 m_output_node_dims[0] = 1;
308
309 if (m_NNInputs.size()!=(long unsigned int)m_input_node_dims[1]) {
310 ATH_MSG_FATAL("DNN input features not the same size as in config, will exit");
311 return StatusCode::FAILURE;
312 }
313
314 // Set jet starting scale
315 m_jetStartScale = "JetConstitScaleMomentum";
316
317 return StatusCode::SUCCESS;
318}
319
320
321
323
324 // Set jet initial scale
325 xAOD::JetFourMom_t jetStartP4;
327 jetStartP4 = jet.jetP4();
328
329 // Don't apply calibration for jets with negative or null mass or for one constituent jets
330 if(jet.m()<=0 || jet.numConstituents()==1){
331 jet.setAttribute<xAOD::JetFourMom_t>("JetDNNCScaleMomentum",jetStartP4);
332 return StatusCode::SUCCESS;
333 }
334
335 // Get input features normalized for jet
336 std::vector<float> input_tensor_values = getJetFeatures(jet, jetEventInfo);
337 if( msgLvl(MSG::DEBUG) ){
338 ATH_MSG_DEBUG("Input tensor values : ");
339 for (long unsigned int i=0;i<input_tensor_values.size();i++) ATH_MSG_DEBUG(" " << input_tensor_values[i]);
340 }
341
342 // Check for nan or +/- inf values
343 int nNan = std::count_if(input_tensor_values.begin(), input_tensor_values.end(), [](float f){return std::isnan(f) || std::isinf(f);});
344 if (nNan>0) {
345 ATH_MSG_WARNING("Encountered Nan or inf value in input features, will not apply calibration");
346 jet.setAttribute<xAOD::JetFourMom_t>("JetDNNCScaleMomentum",jetStartP4);
347 return StatusCode::SUCCESS;
348 }
349
350 // Convert input_tensor_values array to onnx-compatible tensor
351 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
352 Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
353 input_tensor_values.data(),
354 input_tensor_values.size(),
355 m_input_node_dims.data(),
356 m_input_node_dims.size());
357
358 // Make sure we get the same input values in tensor
359 std::vector<float> vec(input_tensor.GetTensorMutableData<float>(), input_tensor.GetTensorMutableData<float>() + m_input_node_dims[1]);
360 if (vec!=input_tensor_values) {
361 ATH_MSG_WARNING("Input tensor after convertion to Ort tensor is not the same as the input vector, will not apply calibration");
362 jet.setAttribute<xAOD::JetFourMom_t>("JetDNNCScaleMomentum",jetStartP4);
363 return StatusCode::SUCCESS;
364 }
365
366 // Run inference on input_tensor
367 Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
368 auto output_tensor = session.Run( Ort::RunOptions{nullptr},
369 m_input_node_names.data(),
370 &input_tensor,
371 m_input_node_names.size(),
372 m_output_node_names.data(),
373 m_output_node_names.size());
374 if (!output_tensor.front().IsTensor() || output_tensor.size() != m_output_node_names.size() || output_tensor.front().GetTensorTypeAndShapeInfo().GetShape() != m_output_node_dims) {
375 ATH_MSG_WARNING("Output tensor does not have the same size as output layer, will not apply calibration");
376 jet.setAttribute<xAOD::JetFourMom_t>("JetDNNCScaleMomentum",jetStartP4);
377 return StatusCode::SUCCESS;
378 }
379
380 // Get pointer to output tensor float values
381 float* outputE = output_tensor.at(0).GetTensorMutableData<float>();
382 float* outputM = output_tensor.at(1).GetTensorMutableData<float>();
383
384 // Get predicted calibration factors
385 float predRespE = outputE[0]; // first element is predicted response
386 float predRespM = outputM[0];
387
388 // Print the output predictions for E/M
389 ATH_MSG_DEBUG("Output E : " << predRespE);
390 ATH_MSG_DEBUG("Output M : " << predRespM);
391
392 if (predRespE==0 || predRespM==0) {
393 ATH_MSG_WARNING("Predictions give 0 values, will not apply calibration");
394 jet.setAttribute<xAOD::JetFourMom_t>("JetDNNCScaleMomentum",jetStartP4);
395 return StatusCode::SUCCESS;
396 }
397
398 // Apply calibration to jet p4
399 float calibE = jetStartP4.e() / predRespE;
400
401 // For mass only apply calibration if m>40 GeV
402 float calibM = jetStartP4.mass();
403 if ( calibM > 40000 ) {
404 calibM /= predRespM;
405 }
406
407 // Propagate energy and mass calibration to jet pT
408 float calibpT = std::sqrt( calibE*calibE - calibM*calibM )/std::cosh( jetStartP4.eta() );
409
410 // Build calibrated jet p4
411 TLorentzVector TLVjet;
412 TLVjet.SetPtEtaPhiM( calibpT, jetStartP4.eta(), jetStartP4.phi(), calibM );
413 xAOD::JetFourMom_t calibP4;
414 calibP4.SetPxPyPzE( TLVjet.Px(), TLVjet.Py(), TLVjet.Pz(), TLVjet.E() );
415
416 // Transfer calibrated jet properties to the Jet object
417 jet.setAttribute<xAOD::JetFourMom_t>("JetDNNCScaleMomentum",calibP4);
418 jet.setJetP4( calibP4 );
419
420 return StatusCode::SUCCESS;
421
422}
423
424
425
426std::vector<float> GlobalLargeRDNNCalibration::getJetFeatures( xAOD::Jet& jet_reco, JetEventInfo& jetEventInfo) const {
427 // Init input tensor
428 std::vector<float> input_tensor_values(m_NNInputs.size());
429
430 // Retrieve all input variables from the jet and/or jetEventInfo using our VarRetriever collection:
431 for(size_t i=0;i<input_tensor_values.size();i++){
432 float v = m_varretrievers[i]->value(jet_reco, jetEventInfo, m_eScales[i]);
433 // and perform normalisation :
434 input_tensor_values[i] = v*m_NormScales[i] + m_NormOffsets[i];
435 }
436
437 return input_tensor_values;
438}
Scalar eta() const
pseudorapidity method
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
@ Data
Definition BaseObject.h:11
std::vector< size_t > vec
macros for messaging and checking status codes
#define ANA_MSG_WARNING(xmsg)
Macro printing warning messages.
#define DEF_RETRIEVER0(cname, expr)
Define shortcuts macro to declare specialized VarRetriever class in one line.
#define DEF_RETRIEVER1(cname, expr)
#define DEF_RATIO_RETRIEVER(cname, expr)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
constexpr int pow(int base, int exp) noexcept
std::atomic_flag m_initialized ATLAS_THREAD_SAFE
Messaging initialized (initMessaging)
virtual StatusCode calibrate(xAOD::Jet &jet, JetEventInfo &) const override
std::vector< float > getJetFeatures(xAOD::Jet &jet_reco, JetEventInfo &jetEventInfo) const
Returns a vector of input features for the NN.
std::vector< const char * > m_output_node_names
std::vector< int64_t > m_output_node_dims
std::vector< VarRetriever * > m_varretrievers
std::unique_ptr< Ort::Session > m_session
std::vector< int64_t > m_input_node_dims
virtual StatusCode initialize() override
Initializes the calibration step: reads the DNN configuration and sets up the ONNX Runtime session.
virtual ~GlobalLargeRDNNCalibration()
The destructor.
std::vector< const char * > m_input_node_names
virtual StatusCode setStartP4(xAOD::Jet &jet) const
JetCalibrationStep(const char *name="JetCalibrationStep")
SG::ConstAccessor< T, ALLOC > ConstAccessor
Definition AuxElement.h:569
bool msgLvl(const MSG::Level lvl) const
Test the output level of the object.
AthROOTErrorHandlerSvc * svc
StrV Vectorize(const TString &str, const TString &sep=" ")
VecD VectorizeD(const TString &str, const TString &sep=" ")
Jet_v1 Jet
Definition of the current "jet version".
ROOT::Math::LorentzVector< ROOT::Math::PtEtaPhiM4D< double > > JetFourMom_t
Base 4 Momentum type for Jet.
Definition JetTypes.h:17
VarRetriever is a generic class to access Jet and/or JetEventInfo variables.
virtual float value(const xAOD::Jet &jet, JetEventInfo &jetInfo, double eScale)=0
the value of the variable to be retrieved from the jet and/or JetEventInfo