ATLAS Offline Software
Loading...
Searching...
No Matches
ools/Root/DiTauOnnxDiscriminantTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7// Core include(s):
10
11using TrackParticleLinks_t = std::vector<ElementLink<xAOD::TrackParticleContainer>>;
12
13using namespace DiTauRecTools;
14
15//=================================PUBLIC-PART==================================
16//______________________________________________________________________________
18 : AsgTool(name)
19{
  // Only forwards `name` to the AsgTool base. The ONNX session and the
  // StoreGate handle keys are set up later, during tool initialization.
20}
21
22//______________________________________________________________________________
24
25//______________________________________________________________________________
27{
  // Tool initialization: locate and load the ONNX model, then initialize the
  // StoreGate handle keys. Returns FAILURE if the model file cannot be found.
28 ATH_MSG_INFO( "Initializing DiTauOnnxDiscriminantTool" );
29 ATH_MSG_INFO( "onnxModelPath: " << m_onnxModelPath );
  // Resolve the configured model path through PathResolver; an empty result
  // means the file could not be located.
30 std::string model_path = PathResolverFindCalibFile(m_onnxModelPath);
31 if (model_path.empty()) {
32 ATH_MSG_ERROR("Could not find model file: " << m_onnxModelPath);
33 return StatusCode::FAILURE;
34 }
  // Both the Ort::Env and the Ort::Session are stored as members
  // (m_ort_env / m_ort_session), so they live as long as the tool.
  // The session uses one intra-op thread and extended graph optimizations.
  // NOTE(review): failures while loading the model are presumably reported by
  // the ONNX Runtime C++ API as exceptions (not caught here) rather than as a
  // StatusCode — confirm this is the intended failure mode.
35 m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "OnnxUtil");
36 Ort::SessionOptions session_options;
37 session_options.SetIntraOpNumThreads(1);
38 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
39 m_ort_session = std::make_unique<Ort::Session>(*m_ort_env, model_path.c_str(), session_options);
40
  // Initialize the container key and the decoration keys; each call is wrapped
  // in ATH_CHECK, which checks the returned StatusCode.
  // NOTE(review): the embedded line numbers skip 44, 52, 54, 56 and 57, so this
  // listing appears to omit some statements here (plausibly the initialize()
  // calls for m_f_core_sublead_DecorKey, m_R_track_sublead_DecorKey,
  // m_M_core_sublead_DecorKey and the two d0_leadtrack keys) — check the
  // original source.
  // NOTE(review): m_R_track_all_DecorKey is initialized but is not read by the
  // visible GetDiTauObjOnnxScore body — confirm it is still needed.
41 ATH_CHECK( m_ditauContainerKey.initialize() );
42 ATH_CHECK( m_ditau_pt_DecorKey.initialize() );
43 ATH_CHECK( m_f_core_lead_DecorKey.initialize() );
45 ATH_CHECK( m_f_subjet_subl_DecorKey.initialize() );
46 ATH_CHECK( m_f_subjets_DecorKey.initialize() );
47 ATH_CHECK( m_R_max_lead_DecorKey.initialize() );
48 ATH_CHECK( m_R_max_sublead_DecorKey.initialize() );
49 ATH_CHECK( m_n_track_DecorKey.initialize() );
50 ATH_CHECK( m_R_track_all_DecorKey.initialize() );
51 ATH_CHECK( m_R_isotrack_DecorKey.initialize() );
53 ATH_CHECK( m_M_core_lead_DecorKey.initialize() );
55 ATH_CHECK( m_M_track_lead_DecorKey.initialize() );
58 ATH_CHECK( m_f_isotracks_DecorKey.initialize() );
59
60 return StatusCode::SUCCESS;
61}
62
64{
  // Score one di-tau candidate `xDiTau` with the ONNX model and attach the
  // result as the float decoration "omni_score". The only return statement
  // returns SUCCESS.
  // NOTE(review): the signature line is missing from this listing; xDiTau is
  // presumably an xAOD::DiTauJet, matching GetDiTauObjOnnxScore's parameter.
  // The decorator is a function-local static, so it is constructed only once.
65 const static SG::Decorator<float> omni_scoreDec("omni_score");
66 ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
67 float score = GetDiTauObjOnnxScore(xDiTau);
68 ATH_MSG_DEBUG("DiTau ID score: " << score);
69 omni_scoreDec(xDiTau) = score;
70 return StatusCode::SUCCESS;
71}
72
73float DiTauOnnxDiscriminantTool::nan_to_num(float value, float nan_replacement = 0.0f, float posinf_replacement = 0.0f, float neginf_replacement = 0.0f) const{
74 if (std::isnan(value))
75 return nan_replacement;
76 if (value == std::numeric_limits<float>::infinity())
77 return posinf_replacement;
78 if (value == -std::numeric_limits<float>::infinity())
79 return neginf_replacement;
80 return value;
81 }
82
83std::vector<float> DiTauOnnxDiscriminantTool::flatten(const std::vector<std::vector<float>> &vec_2d) const{
84 std::vector<float> flattened;
85 flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
86 for (const auto &inner : vec_2d) {
87 flattened.insert(flattened.end(), inner.begin(), inner.end());
88 }
89 return flattened;
90}
91
92std::vector<float> DiTauOnnxDiscriminantTool::extract_points(const std::vector<std::vector<float>> &track_features) const{
93 std::vector<float> points;
94 points.reserve(track_features.size() * 2);
95 for (const auto &track : track_features) {
96 points.push_back(track[0]); // delta_eta
97 points.push_back(track[1]); // delta_phi
98 }
99 return points;
100}
101
102std::vector<float> DiTauOnnxDiscriminantTool::create_mask(const std::vector<std::vector<float>> &track_features) const{
103 std::vector<float> mask;
104 mask.reserve(track_features.size());
105 std::transform(track_features.begin(), track_features.end(), std::back_inserter(mask), [](const auto &track) {
106 return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
107 });
108 return mask;
109}
110
111Ort::Value DiTauOnnxDiscriminantTool::create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const{
112 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
113 return Ort::Value::CreateTensor<float>(memory_info, data.data(), data.size(),shape.data(), shape.size());
114}
115
  // Run the ONNX session on the five model inputs packed in `inputs` and copy
  // the float results back. Output tensor 0 goes to output.output_1 and every
  // later tensor to output.output_2, so if there were more than two outputs
  // the last one would win.
  // `inputs` is non-const because create_tensor wraps its vectors' buffers in
  // place; those vectors must stay alive until Run() returns, which they do here.
  // NOTE(review): assumes m_input_node_names lists exactly these five inputs in
  // this order (features, points, mask, jet, time). Run() is told there are
  // m_input_node_names.size() input tensors, but exactly five are built.
117 std::vector<Ort::Value> input_tensors;
118 input_tensors.reserve(m_input_node_names.size());
119 input_tensors.emplace_back(create_tensor(inputs.input_features, inputs.input_features_shape));
120 input_tensors.emplace_back(create_tensor(inputs.input_points, inputs.input_points_shape));
121 input_tensors.emplace_back(create_tensor(inputs.input_mask, inputs.input_mask_shape));
122 input_tensors.emplace_back(create_tensor(inputs.input_jet, inputs.input_jet_shape));
123 input_tensors.emplace_back(create_tensor(inputs.input_time, inputs.input_time_shape));
124
  // The C API wants arrays of C strings; these pointers borrow from the
  // member std::string vectors, which outlive this call.
125 std::vector<const char *> input_node_names;
126 input_node_names.reserve(m_input_node_names.size());
127 std::transform(m_input_node_names.begin(), m_input_node_names.end(), std::back_inserter(input_node_names), [](const std::string &name) { return name.c_str(); });
128
129 std::vector<const char *> output_node_names;
130 output_node_names.reserve(m_output_node_names.size());
131 std::transform(m_output_node_names.begin(), m_output_node_names.end(), std::back_inserter(output_node_names), [](const std::string &name) { return name.c_str(); });
132
133 auto output_tensors = m_ort_session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
134
  // NOTE(review): embedded line 135 is missing from this listing; it presumably
  // declares the `output` object that is filled below and returned.
  // GetTensorData<float> assumes every output tensor holds float data.
136 for (size_t i = 0; i < output_tensors.size(); ++i) {
137 const auto &tensor = output_tensors[i];
138 const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
139 const float *data = tensor.GetTensorData<float>();
140 (i == 0 ? output.output_1 : output.output_2) = std::vector<float>(data, data + length);
141 }
142 return output;
143}
144
146
  // Compute the ONNX DiTau ID score for one di-tau candidate `ditau`.
  // Builds five model inputs:
  //   - 15 jet-level variables;
  //   - a zero-padded per-track feature matrix (m_maxTracks rows x 11 columns);
  //   - the per-track (delta_eta, delta_phi) points;
  //   - a per-track mask;
  //   - a constant time input.
  // It then runs the model and returns element 1 of the first output.
  // Read handles for the di-tau level decorations used as jet-level inputs.
147 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> ditau_ptDec(m_ditau_pt_DecorKey);
148 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> f_core_leadDec(m_f_core_lead_DecorKey);
149 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> f_core_sublDec(m_f_core_sublead_DecorKey);
150 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> f_subjet_sublDec(m_f_subjet_subl_DecorKey);
151 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> f_subjetDec(m_f_subjets_DecorKey);
152 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> R_max_leadDec(m_R_max_lead_DecorKey);
153 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> R_max_subleadDec(m_R_max_sublead_DecorKey);
154 SG::ReadDecorHandle<xAOD::DiTauJetContainer,int> n_trackDec(m_n_track_DecorKey);
155 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> R_isotracDec(m_R_isotrack_DecorKey);
156 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> R_tracks_sublDec(m_R_track_sublead_DecorKey);
157 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> M_core_leadDec(m_M_core_lead_DecorKey);
158 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> M_core_sublDec(m_M_core_sublead_DecorKey);
159 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> M_tracks_leadDec(m_M_track_lead_DecorKey);
160 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> d0_leadtrack_leadDec(m_d0_leadtrack_lead_DecorKey);
161 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> d0_leadtrack_sublDec(m_d0_leadtrack_sublead_DecorKey);
162 SG::ReadDecorHandle<xAOD::DiTauJetContainer,float> f_isotracks_Dec(m_f_isotracks_DecorKey);
163 // Accessors for reading the necessary features from the xAOD::TrackParticle object
164 static const SG::ConstAccessor< uint8_t > numberOfInrmstPxlLyrHitsAcc ("numberOfInnermostPixelLayerHits");
165 static const SG::ConstAccessor< uint8_t > numberOfPixelHitsAcc ("numberOfPixelHits");
166 static const SG::ConstAccessor< uint8_t > numberOfSCTHitsAcc ("numberOfSCTHits");
167 static const SG::ConstAccessor< float > z0Acc ("z0");
168 static const SG::ConstAccessor< float > d0Acc ("d0");
169
170
171 // Input features for Ditau tagger ONNX model
  // NOTE(review): this order is presumably the order the model was trained
  // with; do not reorder without retraining. n_track is read as int and
  // converted to float. The tensor shape is {1, 15}.
172 std::vector<float> jet_vars = {
173 R_max_leadDec (ditau),
174 R_max_subleadDec (ditau),
175 R_tracks_sublDec (ditau),
176 R_isotracDec (ditau),
177 d0_leadtrack_leadDec (ditau),
178 d0_leadtrack_sublDec (ditau),
179 f_core_leadDec (ditau),
180 f_core_sublDec (ditau),
181 f_subjet_sublDec (ditau),
182 f_subjetDec (ditau),
183 f_isotracks_Dec (ditau),
184 M_core_leadDec (ditau),
185 M_core_sublDec (ditau),
186 M_tracks_leadDec (ditau),
187 static_cast<float>( n_trackDec (ditau)),
188 };
189 std::vector<int64_t> jet_shape = {1, static_cast<int64_t>(jet_vars.size())};
190
  // Per-track features: one row per track slot, zero-initialized so that unused
  // slots (fewer than m_maxTracks tracks, or invalid links) stay all-zero.
  // Column layout:
  //   0 delta_eta, 1 delta_phi, 2 log(pt), 3 d0, 4 log(1 - pt/jet_pt), 5 z0,
  //   6 delta_R, 7 innermost-pixel-layer hits, 8 pixel hits, 9 SCT hits,
  //   10 charge.
191 const TrackParticleLinks_t &vTauTracks = ditau.trackLinks();
192 std::vector<std::vector<float>> track_features(m_maxTracks, std::vector<float>(11, 0.0f));
193
194 float jet_eta = ditau.eta();
195 float jet_phi = ditau.phi();
  // Only the first m_maxTracks links (in trackLinks() order) are used.
196 size_t num_tracks = std::min(static_cast<size_t>(m_maxTracks), vTauTracks.size());
197
198 for (size_t i = 0; i < num_tracks; ++i) {
199 const ElementLink<xAOD::TrackParticleContainer> &trackLink = vTauTracks[i];
  // An invalid link leaves its row all-zero; create_mask then marks it 0.
200 if (!trackLink.isValid()) continue;
201 const xAOD::TrackParticle *xTrack = *trackLink;
202 float track_eta = xTrack->eta();
203 float track_phi = xTrack->phi();
204 float delta_eta = track_eta - jet_eta;
  // std::remainder wraps the azimuthal difference into [-pi, pi].
205 float delta_phi = std::remainder(track_phi - jet_phi, 2 * M_PI);
206 float delta_R = std::hypot(delta_eta, delta_phi);
207 float track_pt = static_cast<float>(xTrack->pt());
  // The small offset avoids log(0) when track_pt is 0.
208 float pt_log = std::log(track_pt + 1e-8f);
  // The ditau_pt decoration (not ditau.pt()) is the reference pT here.
  // NOTE(review): this value does not depend on the track and could be read
  // once before the loop.
209 float jet_pt = ditau_ptDec(ditau);
210 float pt_ratio = track_pt / jet_pt;
211 float pt_ratio_log = std::log(1.0f - pt_ratio + 1e-8f);
212 float track_charge = xTrack->charge();
  // log(1 - pt_ratio) is non-finite e.g. when track_pt > jet_pt (NaN) or when
  // jet_pt is 0; all non-finite results are replaced by 0.
213 float pt_ratio_log_nan_less = nan_to_num(pt_ratio_log, 0.0f, 0.0f, 0.0f);
214
215 track_features[i] = {
216 delta_eta,
217 delta_phi,
218 pt_log,
219 d0Acc(*xTrack),
220 pt_ratio_log_nan_less,
221 z0Acc(*xTrack),
222 delta_R,
223 static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
224 static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
225 static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
226 track_charge
227 };
228 }
229 std::vector<int64_t> track_shape = {1, static_cast<int64_t>(m_maxTracks), 11};
230
  // Remaining inputs:
  //   - points: columns 0-1 of each track row, shape {1, m_maxTracks, 2};
  //   - mask: 1 where |column 2| > 1e-6, shape {1, m_maxTracks};
  //   - time: a single 0.0f of shape {1, 1}.
  // NOTE(review): embedded line 232 is missing from this listing; it presumably
  // opens the braced initializer of the `inputs` object passed to run_inference.
231 // Actual ONNX inference
233 flatten(track_features),
234 track_shape,
235 extract_points(track_features),
236 {1, track_shape[1], 2},
237 create_mask(track_features),
238 {1, track_shape[1]},
239 std::move(jet_vars),
240 std::move(jet_shape),
241 {0.0f},
242 {1, 1}
243 };
244 auto output = run_inference(inputs);
  // NOTE(review): assumes the first model output has at least two elements;
  // index 1 is presumably the signal-class score — confirm against the model.
245 return output.output_1[1];
246}
247
248
#define M_PI
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
Handle class for reading a decoration on an object.
double length(const pvec &v)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
DiTauOnnxDiscriminantTool(const std::string &type, const std::string &name, const IInterface *parent)
InferenceOutput run_inference(OnnxInputs &inputs) const
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
std::vector< float > create_mask(const std::vector< std::vector< float > > &track_features) const
const std::vector< std::string > m_output_node_names
std::unique_ptr< Ort::Session > m_ort_session
std::vector< float > extract_points(const std::vector< std::vector< float > > &track_features) const
std::vector< float > flatten(const std::vector< std::vector< float > > &vec_2d) const
virtual StatusCode execute(DiTauCandidateData *data, const EventContext &ctx) const override
Execute - called for each Ditau candidate.
virtual StatusCode initialize() override
Tool initializer.
const std::vector< std::string > m_input_node_names
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
virtual ~DiTauOnnxDiscriminantTool()
Gaudi::Property< std::string > m_onnxModelPath
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjet_subl_DecorKey
SG::ReadHandleKey< xAOD::DiTauJetContainer > m_ditauContainerKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_track_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_ditau_pt_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_isotracks_DecorKey
float nan_to_num(float value, float nan_replacement, float posinf_replacement, float neginf_replacement) const
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_isotrack_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_all_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjets_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_n_track_DecorKey
Helper class to provide type-safe access to aux data.
Definition Decorator.h:59
virtual double eta() const
The pseudorapidity ( ) of the particle.
virtual double phi() const
The azimuthal angle ( ) of the particle.
const TrackParticleLinks_t & trackLinks() const
virtual double phi() const override final
The azimuthal angle ( ) of the particle (has range to .)
virtual double pt() const override final
The transverse momentum ( ) of the particle.
virtual double eta() const override final
The pseudorapidity ( ) of the particle.
float charge() const
Returns the charge.
Implementation of boosted di-tau ID.
bool pt_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
delta_phi(phi1, phi2)
Definition eFEXNTuple.py:14
delta_R(eta1, phi1, eta2, phi2)
Definition eFEXNTuple.py:20
output
Definition merge.py:16
TrackParticle_v1 TrackParticle
Reference the current persistent version:
DiTauJet_v1 DiTauJet
Definition of the current version.
Definition DiTauJet.h:17
std::vector< ElementLink< xAOD::TrackParticleContainer > > TrackParticleLinks_t