19 #include <Math/Vector2D.h>
20 #include <Math/Vector2Dfwd.h>
33 return StatusCode::FAILURE;
41 ATH_MSG_ERROR(
"Please make sure it exists in the ATLAS calibration area (https://atlas-groupdata.web.cern.ch/atlas-groupdata/), and provide a model file name relative to the root of the calibration area.");
43 return StatusCode::FAILURE;
47 Ort::SessionOptions session_options;
48 Ort::AllocatorWithDefaultOptions allocator;
49 session_options.SetIntraOpNumThreads(1);
50 session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
52 m_session = std::make_unique<Ort::Session>(
m_svc->env(), model_file_name.c_str(), session_options);
54 ATH_MSG_INFO(
"Created ONNX runtime session with model " << model_file_name);
56 size_t num_input_nodes =
m_session->GetInputCount();
59 for (std::size_t
i = 0;
i < num_input_nodes;
i++) {
60 char* input_name =
m_session->GetInputNameAllocated(
i, allocator).release();
64 Ort::TypeInfo type_info =
m_session->GetInputTypeInfo(
i);
65 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
66 ONNXTensorElementDataType
type = tensor_info.GetElementType();
77 std::vector<int64_t> output_node_dims;
78 size_t num_output_nodes =
m_session->GetOutputCount();
82 for (std::size_t
i = 0;
i < num_output_nodes;
i++) {
83 char* output_name =
m_session->GetOutputNameAllocated(
i, allocator).release();
84 ATH_MSG_DEBUG(
"Output " <<
i <<
" : " <<
" name= " << output_name);
87 Ort::TypeInfo type_info =
m_session->GetOutputTypeInfo(
i);
88 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
89 ONNXTensorElementDataType
type = tensor_info.GetElementType();
92 output_node_dims = tensor_info.GetShape();
93 ATH_MSG_DEBUG(
"Output " <<
i <<
" : num_dims= " << output_node_dims.size());
94 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
95 if (output_node_dims[j] < 0) output_node_dims[j] = 1;
96 ATH_MSG_DEBUG(
"Output" <<
i <<
" : dim " << j <<
"= " << output_node_dims[j]);
101 return StatusCode::SUCCESS;
106 ATH_MSG_DEBUG(
"Size of passingLegs = " << passingLegs.size());
108 if (passingLegs.size() == 0) {
109 return StatusCode::SUCCESS;
112 std::vector<std::vector<Combo::LegDecision>> legDecisions;
116 ATH_MSG_DEBUG(
"Have "<<passingLegs.size()<<
" passing legs in AD");
118 std::vector<const xAOD::Jet*> input_jets;
119 std::map<const xAOD::Jet*, std::vector<Combo::LegDecision>> jet_decisions;
120 std::vector<const xAOD::Electron*> input_electrons;
121 std::map<const xAOD::Electron*, std::vector<Combo::LegDecision>> ele_decisions;
122 std::vector<const xAOD::Muon*> input_muons;
123 std::map<const xAOD::Muon*, std::vector<Combo::LegDecision>> muon_decisions;
124 std::vector<const xAOD::Photon*> input_photons;
125 std::map<const xAOD::Photon*, std::vector<Combo::LegDecision>> gam_decisions;
126 std::vector<const xAOD::TauJet*> input_taus;
127 std::map<const xAOD::TauJet*, std::vector<Combo::LegDecision>> taujet_decisions;
128 std::vector<const xAOD::TrigMissingET*> input_mets;
129 std::map<const xAOD::TrigMissingET*, std::vector<Combo::LegDecision>> met_decisions;
131 for(
const auto &leg_decs : legDecisions){
132 for(
const auto &dec_pair : leg_decs){
134 std::vector<TrigCompositeUtils::LinkInfo<xAOD::JetContainer>> jet_feature_links = TrigCompositeUtils::findLinks<xAOD::JetContainer>(decision,
TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
135 std::vector<TrigCompositeUtils::LinkInfo<xAOD::ElectronContainer>> ele_feature_links = TrigCompositeUtils::findLinks<xAOD::ElectronContainer>(decision,
TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
136 std::vector<TrigCompositeUtils::LinkInfo<xAOD::MuonContainer>> muon_feature_links = TrigCompositeUtils::findLinks<xAOD::MuonContainer>(decision,
TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
137 std::vector<TrigCompositeUtils::LinkInfo<xAOD::PhotonContainer>> gam_feature_links = TrigCompositeUtils::findLinks<xAOD::PhotonContainer>(decision,
TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
138 std::vector<TrigCompositeUtils::LinkInfo<xAOD::TauJetContainer>> taujet_feature_links = TrigCompositeUtils::findLinks<xAOD::TauJetContainer>(decision,
TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
139 std::vector<TrigCompositeUtils::LinkInfo<xAOD::TrigMissingETContainer>> met_feature_links = TrigCompositeUtils::findLinks<xAOD::TrigMissingETContainer>(decision,
TrigCompositeUtils::featureString(), TrigDefs::lastFeatureOfType);
140 if(jet_feature_links.size()==1){
144 jet_decisions[
jet].push_back(dec_pair);
146 if(ele_feature_links.size()==1){
150 ele_decisions[
electron].push_back(dec_pair);
152 if(muon_feature_links.size()==1){
156 muon_decisions[
muon].push_back(dec_pair);
158 if(gam_feature_links.size()==1){
162 gam_decisions[
photon].push_back(dec_pair);
164 if(taujet_feature_links.size()==1){
168 taujet_decisions[taujet].push_back(dec_pair);
170 if(met_feature_links.size()==1){
174 met_decisions[
met].push_back(dec_pair);
179 for(
const auto &pair : jet_decisions){
180 input_jets.push_back(pair.first);
182 if(input_jets.size()>1){
183 std::sort(input_jets.begin(), input_jets.end(),
184 [](
const auto a,
const auto b){
185 return a->pt() > b->pt();
189 for(
const auto &pair : ele_decisions){
190 input_electrons.push_back(pair.first);
192 if(input_electrons.size()>1){
193 std::sort(input_electrons.begin(), input_electrons.end(),
194 [](
const auto a,
const auto b){
195 return a->pt() > b->pt();
199 for(
const auto &pair : muon_decisions){
200 input_muons.push_back(pair.first);
202 if(input_muons.size()>1){
203 std::sort(input_muons.begin(), input_muons.end(),
204 [](
const auto a,
const auto b){
205 return a->pt() > b->pt();
209 for(
const auto &pair : gam_decisions){
210 input_photons.push_back(pair.first);
212 if(input_photons.size()>1){
213 std::sort(input_photons.begin(), input_photons.end(),
214 [](
const auto a,
const auto b){
215 return a->pt() > b->pt();
219 for(
const auto &pair : taujet_decisions){
220 input_taus.push_back(pair.first);
222 if(input_taus.size()>1){
223 std::sort(input_taus.begin(), input_taus.end(),
224 [](
const auto a,
const auto b){
225 return a->pt() > b->pt();
229 for(
const auto &pair : met_decisions){
230 input_mets.push_back(pair.first);
233 float outputScore = this->
getAdScore(input_jets, input_muons, input_electrons, input_photons, input_taus, input_mets);
237 auto adScoreContainer = std::make_unique< xAOD::TrigCompositeContainer>();
238 auto adScoreContainerAux = std::make_unique< xAOD::TrigCompositeAuxContainer>();
239 adScoreContainer->setStore(adScoreContainerAux.get());
242 adScoreContainer->push_back(adScore);
243 adScore->
setDetail(
"adScore", outputScore );
246 ATH_CHECK( adScoreHandle.
record( std::move( adScoreContainer ), std::move( adScoreContainerAux ) ) );
253 monScore = outputScore;
262 return StatusCode::SUCCESS;
266 const std::vector<const xAOD::Jet*> &input_jets,
267 const std::vector<const xAOD::Muon*> &input_muons,
268 const std::vector<const xAOD::Electron*> &input_electrons,
269 const std::vector<const xAOD::Photon*> &input_photons,
270 const std::vector<const xAOD::TauJet*> &input_taus,
271 const std::vector<const xAOD::TrigMissingET*> &input_mets)
const{
274 <<
"Jets: " << input_jets.size() <<
", "
275 <<
"Muons: " << input_muons.size() <<
", "
276 <<
"Electrons: " << input_electrons.size() <<
", "
277 <<
"Photons: " << input_photons.size() <<
", "
278 <<
"TauJets: " << input_taus.size() <<
", "
279 <<
"METs: " << input_mets.size());
283 std::vector<float> inputTensor;
285 unsigned int jet_count = 0;
286 for(
const auto &
jet : input_jets){
288 <<
"jet[" << jet_count <<
"] = ("
289 <<
jet->pt()/1000 <<
", "
290 <<
jet->eta() <<
", "
291 <<
jet->phi() <<
", "
292 <<
jet->m()/1000 <<
")");
293 if (jet_count<
m_maxjs.value()) {
294 inputTensor.insert(inputTensor.end(), {static_cast<float>(jet->pt()/1000), static_cast<float>(jet->eta()), static_cast<float>(jet->phi())});
298 inputTensor.insert(inputTensor.end(), 3*(
m_maxjs.value()-jet_count), 0.);
300 unsigned int ele_count = 0;
301 for(
const auto &ele : input_electrons){
303 <<
"ele[" << ele_count <<
"] = ("
304 << ele->pt()/1000 <<
", "
305 << ele->eta() <<
", "
306 << ele->phi() <<
", "
307 << ele->m()/1000 <<
")");
308 if (ele_count<
m_maxes.value()) {
309 inputTensor.insert(inputTensor.end(), {static_cast<float>(ele->pt()/1000), static_cast<float>(ele->eta()), static_cast<float>(ele->phi())});
313 inputTensor.insert(inputTensor.end(), 3*(
m_maxes.value()-ele_count), 0.);
315 unsigned int muon_count = 0;
316 for(
const auto &
muon : input_muons){
318 <<
"muon[" << muon_count <<
"] = ("
319 <<
muon->pt()/1000 <<
", "
320 <<
muon->eta() <<
", "
321 <<
muon->phi() <<
", "
322 <<
muon->m()/1000 <<
")");
323 if (muon_count<
m_maxms.value()) {
324 inputTensor.insert(inputTensor.end(), {static_cast<float>(muon->pt()/1000), static_cast<float>(muon->eta()), static_cast<float>(muon->phi())});
328 inputTensor.insert(inputTensor.end(), 3*(
m_maxms.value()-muon_count), 0.);
330 unsigned int gam_count = 0;
331 for(
const auto &gam : input_photons){
333 <<
"gam[" << gam_count <<
"] = ("
334 << gam->pt()/1000 <<
", "
335 << gam->eta() <<
", "
336 << gam->phi() <<
", "
337 << gam->m()/1000 <<
")");
338 if (gam_count<
m_maxgs.value()) {
339 inputTensor.insert(inputTensor.end(), {static_cast<float>(gam->pt()/1000), static_cast<float>(gam->eta()), static_cast<float>(gam->phi())});
343 inputTensor.insert(inputTensor.end(), 3*(
m_maxgs.value()-gam_count), 0.);
345 inputTensor.insert(inputTensor.end(), {0., 0., 0.});
346 for(
const auto &
met : input_mets){
347 ROOT::Math::XYVectorF metv(
met->ex(),
met->ey());
348 float met_phi = metv.phi();
349 float met_et = metv.r();
353 << met_et/1000 <<
", "
355 inputTensor[metind] = met_et/1000;
356 inputTensor[metind+2] = met_phi;
361 for (
unsigned int i=0;
i<inputTensor.size();
i++){
376 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
386 float *output_score_array = output_tensors.front().GetTensorMutableData<
float>();
387 unsigned int output_size = output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount();
389 float output_score = 0.;
391 ATH_MSG_ERROR(
"Invalid output tensor size: " << output_size);
393 output_score = output_score_array[0];