22#include "TMVA/Reader.h"
34 float largestWeight = 0;
36 for (
const auto *vtx : *vertices) {
38 const auto& trkLinks=vtx->trackParticleLinks();
39 const size_t nTrackLinks=trkLinks.size();
40 for (
unsigned i=0;i<nTrackLinks;++i) {
42 if( vtx->trackWeights()[i] > largestWeight ){
43 vtxWithLargestWeight = vtx;
50 return vtxWithLargestWeight;
90 int nVars,
const std::vector<std::vector<float>>& input_data,
91 const std::shared_ptr<Ort::Session>& sessionHandle,
92 std::vector<int64_t> input_node_dims,
93 std::vector<const char*> input_node_names,
94 std::vector<const char*> output_node_names)
const {
98 std::vector<std::vector<float>> input_tensor_values_ = input_data;
101 size_t input_tensor_size = nVars;
102 std::vector<float> input_tensor_values(nVars);
103 input_tensor_values = input_tensor_values_[0];
107 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
109 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
110 memory_info, input_tensor_values.data(), input_tensor_size,
111 input_node_dims.data(), input_node_dims.size());
114 assert(input_tensor.IsTensor());
117 auto output_tensors =
118 sessionHandle->Run(Ort::RunOptions{
nullptr}, input_node_names.data(),
119 &input_tensor, input_node_names.size(),
120 output_node_names.data(), output_node_names.size());
123 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
127 float* floatarr = output_tensors[0].GetTensorMutableData<
float>();
129 int arrSize =
sizeof(*floatarr) /
sizeof(floatarr[0]);
136 std::tuple<std::vector<int64_t>, std::vector<const char*>>
138 const std::shared_ptr<Ort::Session>& sessionHandle,
139 Ort::AllocatorWithDefaultOptions& allocator) {
141 std::vector<int64_t> input_node_dims;
142 size_t num_input_nodes = sessionHandle->GetInputCount();
143 std::vector<const char*> input_node_names(num_input_nodes);
146 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
148 char* input_name = sessionHandle->GetInputNameAllocated(i, allocator).release();
150 input_node_names[i] = input_name;
153 Ort::TypeInfo type_info = sessionHandle->GetInputTypeInfo(i);
154 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
155 ONNXTensorElementDataType
type = tensor_info.GetElementType();
159 input_node_dims = tensor_info.GetShape();
160 ATH_MSG_DEBUG(
"Input "<<i<<
" : num_dims= "<<input_node_dims.size());
161 for (std::size_t j = 0; j < input_node_dims.size(); j++){
162 if(input_node_dims[j]<0){input_node_dims[j] =1;}
163 ATH_MSG_DEBUG(
"Input"<<i<<
" : dim "<<j<<
"= "<<input_node_dims[j]);
166 return std::make_tuple(input_node_dims, input_node_names);
170 std::tuple<std::vector<int64_t>, std::vector<const char*>>
172 const std::shared_ptr<Ort::Session>& sessionHandle,
173 Ort::AllocatorWithDefaultOptions& allocator) {
175 std::vector<int64_t> output_node_dims;
176 size_t num_output_nodes = sessionHandle->GetOutputCount();
177 std::vector<const char*> output_node_names(num_output_nodes);
180 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
182 char* output_name = sessionHandle->GetOutputNameAllocated(i, allocator).release();
184 output_node_names[i] = output_name;
186 Ort::TypeInfo type_info = sessionHandle->GetOutputTypeInfo(i);
187 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
188 ONNXTensorElementDataType
type = tensor_info.GetElementType();
192 output_node_dims = tensor_info.GetShape();
193 ATH_MSG_DEBUG(
"Output "<<i<<
" : num_dims= "<<output_node_dims.size());
194 for (std::size_t j = 0; j < output_node_dims.size(); j++){
195 if(output_node_dims[j]<0){output_node_dims[j] =1;}
196 ATH_MSG_DEBUG(
"Output"<<i<<
" : dim "<<j<<
"= "<<output_node_dims[j]);
199 return std::make_tuple(output_node_dims, output_node_names);
203 std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
205 const std::string& modelFilePath) {
211 Ort::SessionOptions sessionOptions;
212 sessionOptions.SetIntraOpNumThreads( 1 );
213 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
215 Ort::AllocatorWithDefaultOptions allocator;
217 std::shared_ptr<Ort::Session> sessionHandle = std::make_shared<Ort::Session>( env, modelFileName.c_str(), sessionOptions );
219 ATH_MSG_INFO(
"Created the ONNX Runtime session for model file = " << modelFileName);
220 return std::make_tuple(sessionHandle, allocator);
226 ATH_MSG_INFO(
"Initializing PhotonVertexSelectionTool...");
233 std::vector<std::string> var_names = {
234 "deltaZ := TMath::Min(abs(PrimaryVerticesAuxDyn.z-zCommon)/zCommonError,20)",
235 "deltaPhi := abs(deltaPhi(PrimaryVerticesAuxDyn.phi,egamma_phi))" ,
236 "logSumpt := log10(PrimaryVerticesAuxDyn.sumPt)" ,
237 "logSumpt2 := log10(PrimaryVerticesAuxDyn.sumPt2)"
239 auto *mva1 =
new TMVA::Reader(var_names,
"!Silent:Color");
241 m_mva1 = std::unique_ptr<TMVA::Reader>( mva1 );
243 auto mva2 = std::make_unique<TMVA::Reader>(var_names,
"!Silent:Color");
245 m_mva2 = std::unique_ptr<TMVA::Reader>( std::move(mva2) );
273#ifndef XAOD_STANDALONE
278 return StatusCode::SUCCESS;
285 const EventContext& ctx = Gaudi::Hive::currentContext();
324 deltaZ(*vertex) = std::abs((zCommon.first - vertex->z())/zCommon.second);
329 if(failType!=
nullptr)
331 return StatusCode::SUCCESS;
335 std::vector<std::pair<const xAOD::Vertex*, float>>
342 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
345 if (
getVertexImp( egammas, vertex, ignoreConv, noDecorate, vertexMLP, vtxCase, failType ).isSuccess()) {
348 if(vtxCasePtr!=
nullptr)
349 *vtxCasePtr = vtxCase;
350 if(failTypePtr!=
nullptr)
351 *failTypePtr = failType;
359 bool ignoreConv)
const
361 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
364 return getVertexImp( egammas, prime_vertex, ignoreConv,
false, vertexMLP, vtxcase, failType );
372 std::vector<std::pair<const xAOD::Vertex*, float>>& vertexMLP,
383 return StatusCode::FAILURE;
387 if (!ignoreConv && photons) {
389 if (prime_vertex !=
nullptr) {
392 vertexMLP.emplace_back(prime_vertex, 0.);
393 return StatusCode::SUCCESS;
398 ATH_MSG_VERBOSE(
"Returning hardest vertex. Fail detected (type="<< fail <<
")");
401 vertexMLP.emplace_back(prime_vertex, 10.);
402 return StatusCode::SUCCESS;
411 bool isConverted =
false;
415 if (!ignoreConv && photons) {
416 for (
const auto *
photon: *photons) {
420 return StatusCode::FAILURE;
432 TMVA::Reader *tmva_reader =
new TMVA::Reader();
436 tmva_reader =
m_mva1.get();
440 tmva_reader =
m_mva2.get();
454 std::vector<float> ONNXInputVector;
455 std::vector<std::vector<float>> onnx_input_tensor_values;
456 std::vector<float> TMVAInputVector;
458 float mlp = 0.0, mlp_max = -99999.0;
459 float doSkipByZSigmaScore = -9999.0;
461 float thresGoodVtxScore;
463 else{thresGoodVtxScore = mlp_max;}
469 onnx_input_tensor_values.clear();
472 float sumPt, sumPt2,
deltaPhi, deltaZ;
473 float log10_sumPt, log10_sumPt2;
475 sumPt = (sumPtA)(*vertex);
476 sumPt2 = (sumPt2A)(*vertex);
478 deltaZ = (deltaZA)(*vertex);
480 " sumPt2: " << sumPt2 <<
482 " deltaZ: " << deltaZ);
487 TMVAMethod =
"MLP method";
488 log10_sumPt =
static_cast<float>(log10(sumPt));
489 log10_sumPt2 =
static_cast<float>(log10(sumPt2));
490 TMVAInputVector = {deltaZ,
deltaPhi,log10_sumPt,log10_sumPt2};
496 ONNXInputVector = {sumPt2, sumPt,
deltaPhi, deltaZ};
497 for (
long unsigned int i = 0; i < ONNXInputVector.size(); i++) {
502 if (ONNXInputVector[i] != 0 && std::isinf(ONNXInputVector[i]) !=
true && std::isnan(ONNXInputVector[i]) !=
true){
503 ONNXInputVector[i] = log(std::abs(ONNXInputVector[i]));
506 ONNXInputVector[i] = log(std::abs(0.00000001));
509 onnx_input_tensor_values.push_back(ONNXInputVector);
514 mlp = tmva_reader->EvaluateMVA(TMVAInputVector, TMVAMethod);
529 " log(abs(sumPt2)): " << sumPt2 <<
531 " log(abs(deltaZ)): " << deltaZ);
532 ATH_MSG_VERBOSE(
"ONNX output, isConverted = " << isConverted <<
", mlp=" << mlp);
538 if ((isConverted && deltaZ > 15) || (!isConverted && deltaZ > 10)) {
539 mlp = doSkipByZSigmaScore;
544 vertexMLP.emplace_back(vertex, mlp);
549 prime_vertex = vertex;
556 if (mlp_max <= thresGoodVtxScore) {
557 ATH_MSG_DEBUG(
"No good vertex candidates from pointing, returning hardest vertex.");
564 ATH_MSG_VERBOSE(
"getVertex case "<< (
int)vtxCase <<
" exit code "<< (
int)fail);
565 return StatusCode::SUCCESS;
570 const std::pair<const xAOD::Vertex*, float> &b)
571 {
return a.second > b.second; }
576 if (photons ==
nullptr) {
577 ATH_MSG_WARNING(
"Passed nullptr photon container, returning nullptr vertex from getPrimaryVertexFromConv");
581 std::vector<const xAOD::Vertex*> vertices;
582 const xAOD::Vertex *conversionVertex =
nullptr, *primary =
nullptr;
584 size_t NumberOfTracks = 0;
590 for (
const auto *
photon: *photons) {
591 conversionVertex =
photon->vertex();
592 if (conversionVertex ==
nullptr)
continue;
595 for (
size_t i = 0; i < NumberOfTracks; ++i) {
598 if (gsfTp ==
nullptr)
continue;
603 if (tp ==
nullptr)
continue;
606 if (primary ==
nullptr)
continue;
610 if (std::find(vertices.begin(), vertices.end(), primary) == vertices.end()) {
611 vertices.push_back(primary);
618 if (!vertices.empty()) {
619 if (vertices.size() > 1)
620 ATH_MSG_WARNING(
"Photons associated to different vertices! Returning lead photon association.");
630 TLorentzVector v, v1;
638 cluster =
egamma->caloCluster();
639 if (cluster ==
nullptr) {
640 ATH_MSG_WARNING(
"No cluster associated to egamma, not adding to 4-vector.");
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
bool isValid(const T &p)
Av: we implement here an ATLAS-sepcific convention: all particles which are 99xxxxx are fine.
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)
virtual double e() const
energy
SG::ConstAccessor< T, ALLOC > ConstAccessor
Handle class for adding a decoration to an object.
bool isAvailable()
Test to see if this variable exists in the store, for the referenced object.
float phiBE(const unsigned layer) const
Get the phi in one layer of the EM Calo.
float etaBE(const unsigned layer) const
Get the eta in one layer of the EM Calo.
size_t nTrackParticles() const
Get the number of tracks associated with this vertex.
const TrackParticle * trackParticle(size_t i) const
Get the pointer to a given track that was used in vertex reco.
const std::vector< float > & trackWeights() const
Get all the track weights.
Select isolated Photons, Electrons and Muons.
const xAOD::Vertex * getVertexFromTrack(const xAOD::TrackParticle *track, const xAOD::VertexContainer *vertices)
std::string decorKeyFromKey(const std::string &key, const std::string &deflt)
Extract the decoration part of key.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
const xAOD::TrackParticle * getOriginalTrackParticleFromGSF(const xAOD::TrackParticle *trkPar)
Helper function for getting the "Original" Track Particle (i.e before GSF) via the GSF Track Particle...
const xAOD::Vertex * getHardestVertex(const xAOD::VertexContainer *vertices)
Return vertex with highest sum pT^2.
float getVertexSumPt(const xAOD::Vertex *vertex, int power=1, bool useAux=true)
Loop over track particles associated with vertex and return scalar sum of pT^power in GeV (from auxda...
TLorentzVector getVertexMomentum(const xAOD::Vertex *vertex, bool useAux=true, const std::string &derivationPrefix="")
Return vector sum of tracks associated with vertex (from auxdata if available and useAux = true)
bool passConvSelection(const xAOD::Photon *photon, float convPtCut=2e3)
Check if photon is converted, and tracks have Si hits and pass selection.
std::pair< float, float > getZCommonAndError(const xAOD::EventInfo *eventInfo, const xAOD::EgammaContainer *egammas, float convPtCut=2e3)
Return zCommon and zCommonError.
PhotonContainer_v1 PhotonContainer
Definition of the current "photon container version".
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
setSAddress setEtaMS setDirPhiMS setDirZMS setBarrelRadius setEndcapAlpha setEndcapRadius setInterceptInner setEtaMap setEtaBin setIsTgcFailure setDeltaPt deltaPhi
TrackParticle_v1 TrackParticle
Reference the current persistent version:
VertexContainer_v1 VertexContainer
Definition of the current "Vertex container version".
Vertex_v1 Vertex
Define the latest version of the vertex class.
Egamma_v1 Egamma
Definition of the current "egamma version".
EgammaContainer_v1 EgammaContainer
Definition of the current "egamma container version".