8#include "AthLinks/ElementLink.h"
23 declareInterface<DiTauToolBase > (
this);
32 ATH_MSG_INFO(
"Initializing DiTauOnnxDiscriminantTool" );
36 if (model_path.empty()) {
38 return StatusCode::FAILURE;
40 m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"OnnxUtil");
41 Ort::SessionOptions session_options;
42 session_options.SetIntraOpNumThreads(1);
43 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
44 session_options.DisableCpuMemArena();
46 return StatusCode::SUCCESS;
54 return StatusCode::SUCCESS;
64 omni_scoreDec(*xDitau) = score;
65 return StatusCode::SUCCESS;
74 omni_scoreDec(xDiTau) = score;
75 return StatusCode::SUCCESS;
79 std::vector<float> flattened;
80 flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
81 for (
const auto &inner : vec_2d) {
82 flattened.insert(flattened.end(), inner.begin(), inner.end());
88 std::vector<float> points;
89 points.reserve(track_features.size() * 2);
90 for (
const auto &track : track_features) {
91 points.push_back(track[0]);
92 points.push_back(track[1]);
98 std::vector<float> mask;
99 mask.reserve(track_features.size());
100 std::transform(track_features.begin(), track_features.end(), std::back_inserter(mask), [](
const auto &track) {
101 return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
107 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108 return Ort::Value::CreateTensor<float>(memory_info,
data.data(),
data.size(),shape.data(), shape.size());
112 std::vector<Ort::Value> input_tensors;
114 input_tensors.emplace_back(
create_tensor(inputs.input_features, inputs.input_features_shape));
115 input_tensors.emplace_back(
create_tensor(inputs.input_points, inputs.input_points_shape));
116 input_tensors.emplace_back(
create_tensor(inputs.input_mask, inputs.input_mask_shape));
117 input_tensors.emplace_back(
create_tensor(inputs.input_jet, inputs.input_jet_shape));
118 input_tensors.emplace_back(
create_tensor(inputs.input_time, inputs.input_time_shape));
120 std::vector<const char *> input_node_names;
124 std::vector<const char *> output_node_names;
128 auto output_tensors =
m_ort_session->Run(Ort::RunOptions{
nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
131 for (
size_t i = 0; i < output_tensors.size(); ++i) {
132 const auto &tensor = output_tensors[i];
133 const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
134 const float *
data = tensor.GetTensorData<
float>();
135 (i == 0 ? output.output_1 : output.output_2) = std::vector<float>(
data,
data +
length);
159 std::vector<float> jet_vars = {
160 R_max(ditau, ditauTrackingInfo, 0),
161 R_max(ditau, ditauTrackingInfo, 1),
162 R_tracks(ditau, ditauTrackingInfo, 1),
174 static_cast<float>(
n_track(ditau)),
176 std::vector<int64_t> jet_shape = {1,
static_cast<int64_t
>(jet_vars.size())};
179 std::vector<std::vector<float>> track_features(
m_maxTracks, std::vector<float>(11, 0.0f));
181 float jet_eta = ditau.
eta();
182 float jet_phi = ditau.
phi();
183 size_t num_tracks = std::min(
static_cast<size_t>(
m_maxTracks), vTauTracks.size());
185 for (
size_t i = 0; i < num_tracks; ++i) {
187 if (!trackLink.
isValid())
continue;
189 float track_eta = xTrack->
eta();
190 float track_phi = xTrack->
phi();
191 float delta_eta = track_eta - jet_eta;
192 float delta_phi = std::remainder(track_phi - jet_phi, 2 *
M_PI);
193 float delta_R = std::hypot(delta_eta, delta_phi);
194 float track_pt =
static_cast<float>(xTrack->
pt());
195 float pt_log = std::log(track_pt + 1e-8f);
197 float pt_ratio = track_pt / jet_pt;
198 float pt_ratio_log = (pt_ratio <= 1.0f) ? std::log(1.0f - pt_ratio + 1e-8f) : 0.0f;
199 float track_charge = xTrack->
charge();
201 track_features[i] = {
209 static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
210 static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
211 static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
215 std::vector<int64_t> track_shape = {1,
static_cast<int64_t
>(
m_maxTracks), 11};
222 {1, track_shape[1], 2},
226 std::move(jet_shape),
231 return output.output_1[1];
237 while (xDiTau.
subjetPt(nSubjet) > 0. ){
250 return xDiTau.
fCore(iSubjet);
254 return xDiTau.
subjetPt(iSubjet) / xDiTau.
pt();
267 if (subjetInfo.
subjet_p4.DeltaR(xTrack->
p4()) > Rmax) {
282 for (
int i = 0; i < 2; i++) {
285 R_sum += subjetInfo.
subjet_p4.DeltaR(xTrack->
p4()) * xTrack->
pt();
301 R_sum += subjetInfo.
subjet_p4.DeltaR(xTrack->
p4()) * xTrack->
pt();
311 TLorentzVector allCoreTracks_p4;
314 allCoreTracks_p4 += xTrack->
p4();
316 float mass = allCoreTracks_p4.M();
324 TLorentzVector allTracks_p4;
327 allTracks_p4 += xTrack->
p4();
329 float mass = allTracks_p4.M();
347 iso_pt += xTrack->
pt();
349 if( xDiTau.
pt() == 0.){
352 return iso_pt / xDiTau.
pt();
364 ATH_MSG_WARNING(
"Track " << (!trackLinksAcc.
isAvailable(xDiTau) ?
"DiTauJet.trackLinks" :
"DiTauJet.isoTrackLinks") <<
" links not available.");
365 return StatusCode::FAILURE;
369 float Rsubjet = R_subjetAcc(xDiTau);
370 float RCore = R_coreAcc(xDiTau);
378 std::vector<ElementLink<xAOD::TrackParticleContainer>> isoTrackLinks = xDiTau.
isoTrackLinks();
379 for (
const auto &trackLink: isoTrackLinks) {
380 if (!trackLink.isValid()) {
387 std::vector<ElementLink<xAOD::TrackParticleContainer>> trackLinks = xDiTau.
trackLinks();
388 for (
const auto &trackLink : trackLinks) {
389 if (!trackLink.isValid()) {
394 trackingInfo.
vTracks.push_back(xTrack);
397 for (
int i=0; i<nSubjets; ++i){
399 TLorentzVector subjet_p4 = TLorentzVector();
401 subjetTrackingInfo.
subjet_p4 = subjet_p4;
402 trackingInfo.
vSubjetInfo.push_back(subjetTrackingInfo);
404 for (
const auto track : trackingInfo.
vTracks) {
407 for (
int i=0; i<nSubjets; ++i){
408 float dRTrackSubjet = trackingInfo.
vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
409 if (dRTrackSubjet < Rsubjet && dRTrackSubjet < dRMin){
410 dRMin = dRTrackSubjet;
415 trackingInfo.
vSubjetInfo[inSubjet].vTracks.push_back(track);
419 for (
int i=0; i<nSubjets; ++i){
420 float ptLeadTrack = 0;
421 for (
const auto track : trackingInfo.
vSubjetInfo[i].vTracks){
422 if (track->pt() > ptLeadTrack){
423 ptLeadTrack = track->pt();
429 for (
int i=0; i<nSubjets; ++i){
430 for (
const auto track : trackingInfo.
vSubjetInfo[i].vTracks){
431 auto subjetTrackingInfo = trackingInfo.
vSubjetInfo[i];
432 if (subjetTrackingInfo.subjet_p4.DeltaR(track->p4()) < RCore){
433 trackingInfo.
vSubjetInfo[i].vCoreTracks.push_back(track);
438 for (
const auto track : trackingInfo.
vIsoTracks){
442 for (
int i=0; i<nSubjets; ++i){
443 float dRTrackSubjet = trackingInfo.
vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
444 if (dRTrackSubjet > Rsubjet && dRTrackSubjet < RIso && dRTrackSubjet < dRMin){
445 dRMin = dRTrackSubjet;
450 trackingInfo.
vSubjetInfo[inSubjet].vIsoTracks.push_back(track);
453 return StatusCode::SUCCESS;
#define ATH_MSG_WARNING(x)
char data[hepevt_bytes_allocation_ATLAS]
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
ElementLink implementation for ROOT usage.
bool isValid() const
Test to see if the link can be dereferenced.
Helper class to provide type-safe access to aux data.
bool isAvailable(const ELT &e) const
Test to see if this variable exists in the store.
const TrackParticleLinks_t & isoTrackLinks() const
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
float fCore(unsigned int numSubjet) const
float subjetEta(unsigned int numSubjet) const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
float subjetE(unsigned int numSubjet) const
float subjetPt(unsigned int numSubjet) const
float subjetPhi(unsigned int numSubjet) const
virtual double phi() const
The azimuthal angle ($\phi$) of the particle.
const TrackParticleLinks_t & trackLinks() const
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
virtual double phi() const override final
The azimuthal angle ($\phi$) of the particle (has range $-\pi$ to $+\pi$.)
float d0() const
Returns the $d_0$ parameter.
virtual double pt() const override final
The transverse momentum ($p_T$) of the particle.
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
float charge() const
Returns the charge.
TrackParticle_v1 TrackParticle
Reference the current persistent version:
DiTauJet_v1 DiTauJet
Definition of the current version.