ATLAS Offline Software
Loading...
Searching...
No Matches
TrigADComboHypoTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
7#include "xAODJet/Jet.h"
9#include "xAODMuon/Muon.h"
11#include "xAODEgamma/Electron.h"
13#include "xAODEgamma/Photon.h"
15#include "xAODTau/TauJet.h"
19#include <Math/Vector2D.h>
20#include <Math/Vector2Dfwd.h>
21
22#include <cmath>
23
24TrigADComboHypoTool::TrigADComboHypoTool(const std::string& type, const std::string& name, const IInterface* parent): ComboHypoToolBase(type, name, parent) {}
25
27
// initialize(): retrieve the (optional) monitoring tool and the ONNX runtime
// service, resolve the model file, create the ONNX runtime session, and cache
// the model's input/output node names and input shape, which runInference()
// later passes to Ort::Value::CreateTensor and Ort::Session::Run.
28 if (!m_monTool.empty()) ATH_CHECK(m_monTool.retrieve());
29
// The output key must be configured explicitly; the literal "Undefined" is fatal.
31 if (m_adScoreKey.key() == "Undefined") {
32 ATH_MSG_ERROR("AD score key name is undefined" );
33 return StatusCode::FAILURE;
34 }
35
36 ATH_CHECK( m_svc.retrieve() );
// Resolve the model relative to the calibration area.
// NOTE(review): m_modelFileName.empty() is only tested after the lookup below;
// testing it first would avoid a pointless PathResolver search.
37 std::string model_file_name = PathResolverFindCalibFile(m_modelFileName);
38
39 if (m_modelFileName.empty() || model_file_name.empty()) {
40 ATH_MSG_ERROR("Could not find the requested ONNX model file: " << m_modelFileName);
41 ATH_MSG_ERROR("Please make sure it exists in the ATLAS calibration area (https://atlas-groupdata.web.cern.ch/atlas-groupdata/), and provide a model file name relative to the root of the calibration area.");
42
43 return StatusCode::FAILURE;
44 }
45
46 // initialise session
// One intra-op thread and basic graph optimisations.
47 Ort::SessionOptions session_options;
48 Ort::AllocatorWithDefaultOptions allocator;
49 session_options.SetIntraOpNumThreads(1);
50 session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
51
52 m_session = std::make_unique<Ort::Session>(m_svc->env(), model_file_name.c_str(), session_options);
53
54 ATH_MSG_INFO("Created ONNX runtime session with model " << model_file_name);
55
56 size_t num_input_nodes = m_session->GetInputCount();
57 m_input_node_names.resize(num_input_nodes);
58
59 for (std::size_t i = 0; i < num_input_nodes; i++) {
// release() detaches the name buffer from its owning Ort::AllocatedStringPtr so
// the raw pointer kept in m_input_node_names outlives this loop.
// NOTE(review): nothing visible here frees that buffer; confirm it is released
// elsewhere (e.g. finalize()), otherwise it leaks once per tool instance.
60 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
61 ATH_MSG_DEBUG("Input " << i << " : " << " name= " << input_name);
62 m_input_node_names[i] = input_name;
63
64 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
65 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
66 ONNXTensorElementDataType type = tensor_info.GetElementType();
67 ATH_MSG_DEBUG("Input " << i << " : " << " type= " << type);
68
// NOTE(review): m_input_node_dims is overwritten on every iteration, so only the
// shape of the last input survives. runInference() feeds a single tensor with
// this shape, i.e. the model is presumably expected to have exactly one input -
// TODO validate num_input_nodes == 1 here.
69 m_input_node_dims = tensor_info.GetShape();
70 ATH_MSG_DEBUG("Input " << i << " : num_dims= " << m_input_node_dims.size());
71 for (std::size_t j = 0; j < m_input_node_dims.size(); j++) {
// Dynamic (negative) dimensions are pinned to 1 (presumably the batch axis).
72 if (m_input_node_dims[j] < 0) m_input_node_dims[j] = 1;
73 ATH_MSG_DEBUG("Input " << i << " : dim " << j << "= " << m_input_node_dims[j]);
74 }
75 }
76
// Output shapes are only inspected for debug printing; they are not stored.
77 std::vector<int64_t> output_node_dims;
78 size_t num_output_nodes = m_session->GetOutputCount();
79 ATH_MSG_DEBUG("Have output nodes " << num_output_nodes);
80 m_output_node_names.resize(num_output_nodes);
81
82 for (std::size_t i = 0; i < num_output_nodes; i++) {
// Same ownership caveat as for the input names above.
83 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
84 ATH_MSG_DEBUG("Output " << i << " : " << " name= " << output_name);
85 m_output_node_names[i] = output_name;
86
87 Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
88 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
89 ONNXTensorElementDataType type = tensor_info.GetElementType();
90 ATH_MSG_DEBUG("Output " << i << " : " << " type= " << type);
91
92 output_node_dims = tensor_info.GetShape();
93 ATH_MSG_DEBUG("Output " << i << " : num_dims= " << output_node_dims.size());
94 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
95 if (output_node_dims[j] < 0) output_node_dims[j] = 1;
96 ATH_MSG_DEBUG("Output" << i << " : dim " << j << "= " << output_node_dims[j]);
97 }
98 }
99
100
101 return StatusCode::SUCCESS;
102}
103
104StatusCode TrigADComboHypoTool::decide(Combo::LegDecisionsMap& passingLegs, const EventContext& context) const{
105
106 ATH_MSG_DEBUG("Size of passingLegs = " << passingLegs.size());
107
108 if (passingLegs.size() == 0) { // if no combinations passed, then exit
109 return StatusCode::SUCCESS;
110 }
111
112 std::vector<std::vector<Combo::LegDecision>> legDecisions;
113 ATH_CHECK(selectLegs(passingLegs, legDecisions));
114
115 // get the lists of each object
116 ATH_MSG_DEBUG("Have "<<passingLegs.size()<<" passing legs in AD");
117
118 std::vector<const xAOD::Jet*> input_jets;
119 std::map<const xAOD::Jet*, std::vector<Combo::LegDecision>> jet_decisions;
120 std::vector<const xAOD::Electron*> input_electrons;
121 std::map<const xAOD::Electron*, std::vector<Combo::LegDecision>> ele_decisions;
122 std::vector<const xAOD::Muon*> input_muons;
123 std::map<const xAOD::Muon*, std::vector<Combo::LegDecision>> muon_decisions;
124 std::vector<const xAOD::Photon*> input_photons;
125 std::map<const xAOD::Photon*, std::vector<Combo::LegDecision>> gam_decisions;
126 std::vector<const xAOD::TauJet*> input_taus;
127 std::map<const xAOD::TauJet*, std::vector<Combo::LegDecision>> taujet_decisions;
128 std::vector<const xAOD::TrigMissingET*> input_mets;
129 std::map<const xAOD::TrigMissingET*, std::vector<Combo::LegDecision>> met_decisions;
130
131 for(const auto &leg_decs : legDecisions){ // loop over each leg
132 for(const auto &dec_pair : leg_decs){ // loop over each object in a leg
133 const TrigCompositeUtils::Decision* decision(*(dec_pair.second));
134 std::vector<TrigCompositeUtils::LinkInfo<xAOD::JetContainer>> jet_feature_links = TrigCompositeUtils::findLinks<xAOD::JetContainer>(decision, TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
135 std::vector<TrigCompositeUtils::LinkInfo<xAOD::ElectronContainer>> ele_feature_links = TrigCompositeUtils::findLinks<xAOD::ElectronContainer>(decision, TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
136 std::vector<TrigCompositeUtils::LinkInfo<xAOD::MuonContainer>> muon_feature_links = TrigCompositeUtils::findLinks<xAOD::MuonContainer>(decision, TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
137 std::vector<TrigCompositeUtils::LinkInfo<xAOD::PhotonContainer>> gam_feature_links = TrigCompositeUtils::findLinks<xAOD::PhotonContainer>(decision, TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
138 std::vector<TrigCompositeUtils::LinkInfo<xAOD::TauJetContainer>> taujet_feature_links = TrigCompositeUtils::findLinks<xAOD::TauJetContainer>(decision, TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
139 std::vector<TrigCompositeUtils::LinkInfo<xAOD::TrigMissingETContainer>> met_feature_links = TrigCompositeUtils::findLinks<xAOD::TrigMissingETContainer>(decision, TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
140 if(jet_feature_links.size()==1){
141 const TrigCompositeUtils::LinkInfo<xAOD::JetContainer> jet_feature_link = jet_feature_links.at(0);
142 ATH_CHECK(jet_feature_link.isValid());
143 const xAOD::Jet* jet = *(jet_feature_link.link);
144 jet_decisions[jet].push_back(dec_pair);
145 }
146 if(ele_feature_links.size()==1){
147 const TrigCompositeUtils::LinkInfo<xAOD::ElectronContainer> ele_feature_link = ele_feature_links.at(0);
148 ATH_CHECK(ele_feature_link.isValid());
149 const xAOD::Electron* electron = *(ele_feature_link.link);
150 ele_decisions[electron].push_back(dec_pair);
151 }
152 if(muon_feature_links.size()==1){
153 const TrigCompositeUtils::LinkInfo<xAOD::MuonContainer> muon_feature_link = muon_feature_links.at(0);
154 ATH_CHECK(muon_feature_link.isValid());
155 const xAOD::Muon* muon = *(muon_feature_link.link);
156 muon_decisions[muon].push_back(dec_pair);
157 }
158 if(gam_feature_links.size()==1){
159 const TrigCompositeUtils::LinkInfo<xAOD::PhotonContainer> gam_feature_link = gam_feature_links.at(0);
160 ATH_CHECK(gam_feature_link.isValid());
161 const xAOD::Photon* photon = *(gam_feature_link.link);
162 gam_decisions[photon].push_back(dec_pair);
163 }
164 if(taujet_feature_links.size()==1){
165 const TrigCompositeUtils::LinkInfo<xAOD::TauJetContainer> taujet_feature_link = taujet_feature_links.at(0);
166 ATH_CHECK(taujet_feature_link.isValid());
167 const xAOD::TauJet* taujet = *(taujet_feature_link.link);
168 taujet_decisions[taujet].push_back(dec_pair);
169 }
170 if(met_feature_links.size()==1){
171 const TrigCompositeUtils::LinkInfo<xAOD::TrigMissingETContainer> met_feature_link = met_feature_links.at(0);
172 ATH_CHECK(met_feature_link.isValid());
173 const xAOD::TrigMissingET* met = *(met_feature_link.link);
174 met_decisions[met].push_back(dec_pair);
175 }
176 }
177 }
178
179 for(const auto &pair : jet_decisions){
180 input_jets.push_back(pair.first);
181 }
182 if(input_jets.size()>1){
183 std::sort(input_jets.begin(), input_jets.end(),
184 [](const auto a, const auto b){
185 return a->pt() > b->pt();
186 });
187 }
188
189 for(const auto &pair : ele_decisions){
190 input_electrons.push_back(pair.first);
191 }
192 if(input_electrons.size()>1){
193 std::sort(input_electrons.begin(), input_electrons.end(),
194 [](const auto a, const auto b){
195 return a->pt() > b->pt();
196 });
197 }
198
199 for(const auto &pair : muon_decisions){
200 input_muons.push_back(pair.first);
201 }
202 if(input_muons.size()>1){
203 std::sort(input_muons.begin(), input_muons.end(),
204 [](const auto a, const auto b){
205 return a->pt() > b->pt();
206 });
207 }
208
209 for(const auto &pair : gam_decisions){
210 input_photons.push_back(pair.first);
211 }
212 if(input_photons.size()>1){
213 std::sort(input_photons.begin(), input_photons.end(),
214 [](const auto a, const auto b){
215 return a->pt() > b->pt();
216 });
217 }
218
219 for(const auto &pair : taujet_decisions){
220 input_taus.push_back(pair.first);
221 }
222 if(input_taus.size()>1){
223 std::sort(input_taus.begin(), input_taus.end(),
224 [](const auto a, const auto b){
225 return a->pt() > b->pt();
226 });
227 }
228
229 for(const auto &pair : met_decisions){
230 input_mets.push_back(pair.first);
231 }
232
233 float outputScore = this->getAdScore(input_jets, input_muons, input_electrons, input_photons, input_taus, input_mets);
234
235 // Recording Data
236 if(!m_adScoreKey.empty()){
237 auto adScoreContainer = std::make_unique< xAOD::TrigCompositeContainer>();
238 auto adScoreContainerAux = std::make_unique< xAOD::TrigCompositeAuxContainer>();
239 adScoreContainer->setStore(adScoreContainerAux.get());
240
242 adScoreContainer->push_back(adScore);
243 adScore->setDetail( "adScore", outputScore );
244
246 ATH_CHECK( adScoreHandle.record( std::move( adScoreContainer ), std::move( adScoreContainerAux ) ) );
247 }
248
249 // Monitoring
250 if(m_monFlag){
251 auto monScore = Monitored::Scalar<float>("adScore", -1.0);
252 auto monGroup = Monitored::Group(m_monTool, monScore); // possible use in future
253 monScore = outputScore;
254 }
255
256 bool trigPass = (outputScore > m_adScoreThres);
257
258 if(!trigPass){
259 eraseFromLegDecisionsMap(passingLegs);
260 }
261
262 return StatusCode::SUCCESS;
263}
264
266 const std::vector<const xAOD::Jet*> &input_jets,
267 const std::vector<const xAOD::Muon*> &input_muons,
268 const std::vector<const xAOD::Electron*> &input_electrons,
269 const std::vector<const xAOD::Photon*> &input_photons,
270 const std::vector<const xAOD::TauJet*> &input_taus,
271 const std::vector<const xAOD::TrigMissingET*> &input_mets) const{
// Build the flat model input and return the score computed by runInference().
// Layout, in this order: jets, electrons, muons, photons. Each type contributes
// up to m_maxjs / m_maxes / m_maxms / m_maxgs slots of (pt/1000, eta, phi);
// unused slots are zero-filled. The final 3 entries are the MET slot, filled as
// (et/1000, 0, phi). The /1000 presumably converts MeV to GeV (TODO confirm units).
// Objects are taken in the order given; only the first N of each type are used.
// NOTE(review): input_taus is only counted in the debug message below and does
// not enter the tensor.
272
273 ATH_MSG_DEBUG( "Counting AD input objects in the event ... "
274 << "Jets: " << input_jets.size() << ", "
275 << "Muons: " << input_muons.size() << ", "
276 << "Electrons: " << input_electrons.size() << ", "
277 << "Photons: " << input_photons.size() << ", "
278 << "TauJets: " << input_taus.size() << ", "
279 << "METs: " << input_mets.size());
280
281 // pt1 eta1 phi1 pt2 eta2 phi2 ... for 6 jets, 3 electrons, 3 muons, 3 photons, and MET
// The per-type counts quoted above are not enforced here; the actual counts
// come from the m_maxjs / m_maxes / m_maxms / m_maxgs properties.
// metind is the offset of the MET slot, directly after all object slots.
282 unsigned int metind = (m_maxjs.value()+m_maxes.value()+m_maxms.value()+m_maxgs.value())*3;
283 std::vector<float> inputTensor;
284
285 unsigned int jet_count = 0;
286 for(const auto &jet : input_jets){
287 ATH_MSG_DEBUG( std::setprecision(3) << std::fixed
288 << "jet[" << jet_count << "] = ("
289 << jet->pt()/1000 << ", "
290 << jet->eta() << ", "
291 << jet->phi() << ", "
292 << jet->m()/1000 << ")");
293 if (jet_count<m_maxjs.value()) {
294 inputTensor.insert(inputTensor.end(), {static_cast<float>(jet->pt()/1000), static_cast<float>(jet->eta()), static_cast<float>(jet->phi())});
295 jet_count++;
296 }
297 }
// Zero-pad the remaining jet slots.
298 inputTensor.insert(inputTensor.end(), 3*(m_maxjs.value()-jet_count), 0.);
299
300 unsigned int ele_count = 0;
301 for(const auto &ele : input_electrons){
302 ATH_MSG_DEBUG( std::setprecision(3) << std::fixed
303 << "ele[" << ele_count << "] = ("
304 << ele->pt()/1000 << ", "
305 << ele->eta() << ", "
306 << ele->phi() << ", "
307 << ele->m()/1000 << ")");
308 if (ele_count<m_maxes.value()) {
309 inputTensor.insert(inputTensor.end(), {static_cast<float>(ele->pt()/1000), static_cast<float>(ele->eta()), static_cast<float>(ele->phi())});
310 ele_count++;
311 }
312 }
313 inputTensor.insert(inputTensor.end(), 3*(m_maxes.value()-ele_count), 0.);
314
315 unsigned int muon_count = 0;
316 for(const auto &muon : input_muons){
317 ATH_MSG_DEBUG( std::setprecision(3) << std::fixed
318 << "muon[" << muon_count << "] = ("
319 << muon->pt()/1000 << ", "
320 << muon->eta() << ", "
321 << muon->phi() << ", "
322 << muon->m()/1000 << ")");
323 if (muon_count<m_maxms.value()) {
324 inputTensor.insert(inputTensor.end(), {static_cast<float>(muon->pt()/1000), static_cast<float>(muon->eta()), static_cast<float>(muon->phi())});
325 muon_count++;
326 }
327 }
328 inputTensor.insert(inputTensor.end(), 3*(m_maxms.value()-muon_count), 0.);
329
330 unsigned int gam_count = 0;
331 for(const auto &gam : input_photons){
332 ATH_MSG_DEBUG( std::setprecision(3) << std::fixed
333 << "gam[" << gam_count << "] = ("
334 << gam->pt()/1000 << ", "
335 << gam->eta() << ", "
336 << gam->phi() << ", "
337 << gam->m()/1000 << ")");
338 if (gam_count<m_maxgs.value()) {
339 inputTensor.insert(inputTensor.end(), {static_cast<float>(gam->pt()/1000), static_cast<float>(gam->eta()), static_cast<float>(gam->phi())});
340 gam_count++;
341 }
342 }
343 inputTensor.insert(inputTensor.end(), 3*(m_maxgs.value()-gam_count), 0.);
344
// Reserve the MET slot (zeros if no MET is given); its middle entry stays 0.
345 inputTensor.insert(inputTensor.end(), {0., 0., 0.});
// NOTE(review): with several MET objects each one overwrites the same slot, so
// only the last entry of input_mets is used.
346 for(const auto &met : input_mets){
347 ROOT::Math::XYVectorF metv(met->ex(),met->ey());
348 float met_phi = metv.phi();
349 float met_et = metv.r();
350
351 ATH_MSG_DEBUG( std::setprecision(3) << std::fixed
352 << "MET = ("
353 << met_et/1000 << ", "
354 << met_phi << ")");
355 inputTensor[metind] = met_et/1000;
356 inputTensor[metind+2] = met_phi;
357 }
358
359 ATH_MSG_DEBUG("inputTensor size = " << inputTensor.size());
360 if(msgLvl(MSG::DEBUG)){
361 for (unsigned int i=0; i<inputTensor.size(); i++){
362 ATH_MSG_DEBUG("inputTensor[" << i << "] = " << inputTensor[i]);
363 }
364 }
365
366 float outputScore = runInference(inputTensor);
367 ATH_MSG_DEBUG("Computed TrigADScore: " << outputScore);
368
369 return outputScore;
370}
371
372float TrigADComboHypoTool::runInference(std::vector<float> &tensor) const {
373
374 ATH_MSG_DEBUG("in TrigADComboHypoTool::runInference()");
375
376 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
377 int input_tensor_size = (m_maxjs.value()+m_maxes.value()+m_maxms.value()+m_maxgs.value()+1)*3;
378 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, tensor.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
379
380 // Ort::Session::Run is non-const.
381 // However, the onnx authors claim that it is safe to call from multiple threads:
382 // https://github.com/Microsoft/onnxruntime/issues/114
383 Ort::Session* session ATLAS_THREAD_SAFE = m_session.get();
384 auto output_tensors = session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
385
386 float *output_score_array = output_tensors.front().GetTensorMutableData<float>();
387 unsigned int output_size = output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount();
388
389 float output_score = 0.;
390 if(output_size!=1){
391 ATH_MSG_ERROR("Invalid output tensor size: " << output_size);
392 }else{
393 output_score = output_score_array[0];
394 }
395
396 return output_score;
397}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
static Double_t a
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define ATLAS_THREAD_SAFE
StatusCode selectLegs(const Combo::LegDecisionsMap &IDCombMap, std::vector< std::vector< Combo::LegDecision > > &leg_decisions) const
Creates the per-leg vectors of Decision objects starting from the initial LegDecision map,...
void eraseFromLegDecisionsMap(Combo::LegDecisionsMap &passingLegs) const
For when the tool rejects all combinations.
ComboHypoToolBase(const std::string &type, const std::string &name, const IInterface *parent)
Group of local monitoring quantities and retain correlation when filling histograms
Declare a monitored scalar variable.
StatusCode record(std::unique_ptr< T > data)
Record a const object to the store.
float runInference(std::vector< float > &tensor) const
Gaudi::Property< double > m_adScoreThres
Gaudi::Property< unsigned int > m_maxgs
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
TrigADComboHypoTool(const std::string &type, const std::string &name, const IInterface *parent)
SG::WriteHandleKey< xAOD::TrigCompositeContainer > m_adScoreKey
Gaudi::Property< std::string > m_modelFileName
Gaudi::Property< unsigned int > m_maxjs
std::vector< const char * > m_output_node_names
float getAdScore(const std::vector< const xAOD::Jet * > &input_jets, const std::vector< const xAOD::Muon * > &input_muons, const std::vector< const xAOD::Electron * > &input_electrons, const std::vector< const xAOD::Photon * > &input_photons, const std::vector< const xAOD::TauJet * > &input_taus, const std::vector< const xAOD::TrigMissingET * > &input_mets) const
Gaudi::Property< unsigned int > m_maxms
std::unique_ptr< Ort::Session > m_session
std::vector< const char * > m_input_node_names
Gaudi::Property< bool > m_monFlag
Gaudi::Property< unsigned int > m_maxes
virtual StatusCode initialize() override
std::vector< int64_t > m_input_node_dims
ToolHandle< GenericMonitoringTool > m_monTool
virtual StatusCode decide(Combo::LegDecisionsMap &passingLegs, const EventContext &context) const override
retrieves the decisions associated to this decId, make their combinations and apply the algorithm
STL class.
bool setDetail(const std::string &name, const TYPE &value)
Set an TYPE detail on the object.
const std::string & featureString()
void findLinks(const Decision *start, const std::string &linkName, std::vector< LinkInfo< T > > &links, unsigned int behaviour=TrigDefs::allFeaturesOfType, std::set< const xAOD::TrigComposite * > *fullyExploredFrom=nullptr)
search back the TC links for the object of type T linked to the one of TC (recursively) Populates pro...
static const unsigned int lastFeatureOfType
Run 3 "enum". Only return the final feature along each route through the navigation.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Jet_v1 Jet
Definition of the current "jet version".
TrigComposite_v1 TrigComposite
Declare the latest version of the class.
TauJet_v3 TauJet
Definition of the current "tau version".
TrigMissingET_v1 TrigMissingET
Define the most recent version of the TrigMissingET class.
Muon_v1 Muon
Reference the current persistent version:
Photon_v1 Photon
Definition of the current "egamma version".
Electron_v1 Electron
Definition of the current "egamma version".
Helper to keep a Decision object, ElementLink and ActiveState (with respect to some requested ChainGr...
Definition LinkInfo.h:22
ElementLink< T > link
Link to the feature.
Definition LinkInfo.h:55